From 9a4d0c6fe4170200ed9c952eb95efdeed9367a3b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Oct 2024 14:35:24 -0500 Subject: [PATCH 1/7] IdentityMapper: ensure CallWithKwargs has immutable kw_parameters --- pymbolic/mapper/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 00d0dfc7..84c166b3 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -496,9 +496,9 @@ def map_call_with_kwargs(self, expr, *args, **kwargs): parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters ]) - kw_parameters = { + kw_parameters = immutabledict({ key: self.rec(val, *args, **kwargs) - for key, val in expr.kw_parameters.items()} + for key, val in expr.kw_parameters.items()}) if (function is expr.function and all(child is orig_child for child, orig_child in From 1b89fa8c96b9703e78837a7e00179b92099fa98e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Oct 2024 14:36:47 -0500 Subject: [PATCH 2/7] CallWithKwargs: warn (don't fail) on non-hashable kw_parameters --- pymbolic/primitives.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 838bad1a..57948ab5 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -38,6 +38,7 @@ ) from warnings import warn +from immutabledict import immutabledict from typing_extensions import TypeIs, dataclass_transform from . import traits @@ -1112,6 +1113,17 @@ class CallWithKwargs(AlgebraicLeaf): parameters: tuple[ExpressionT, ...] kw_parameters: Mapping[str, ExpressionT] + def __post_init__(self): + try: + hash(self.kw_parameters) + except Exception: + warn("CallWithKwargs created with non-hashable kw_parameters. " + "This is deprecated and will stop working in 2025. " + "If you need an immutable mapping, try the immutabledcit package.", + DeprecationWarning, stacklevel=3 + ) + object.__setattr__(self, "kw_parameters", immutabledict(self.kw_parameters)) + @expr_dataclass() class Subscript(AlgebraicLeaf): From e9bf103e62546077ab8f2828a5a0ef1ea16368d8 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Oct 2024 15:58:46 -0500 Subject: [PATCH 3/7] Export Variable, Expression, and the type aliases from the pacakge root --- pymbolic/__init__.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/pymbolic/__init__.py b/pymbolic/__init__.py index 49470dd9..1d92fb7e 100644 --- a/pymbolic/__init__.py +++ b/pymbolic/__init__.py @@ -40,16 +40,20 @@ from .polynomial import Polynomial -from .primitives import Variable as var # noqa: N813 -from .primitives import variables -from .primitives import flattened_sum -from .primitives import subscript -from .primitives import flattened_product -from .primitives import quotient -from .primitives import linear_combination -from .primitives import make_common_subexpression as cse -from .primitives import make_sym_vector -from .primitives import disable_subscript_by_getitem +from .primitives import (Variable as var, # noqa: N813 + Variable, + Expression, + variables, + flattened_sum, + subscript, + flattened_product, + quotient, + linear_combination, + make_common_subexpression as cse, + make_sym_vector, + disable_subscript_by_getitem, + expr_dataclass, +) from .parser import parse from .mapper.evaluator import evaluate from .mapper.evaluator import evaluate_kw @@ -60,10 +64,18 @@ from .mapper.distributor import distribute as expand from .mapper.distributor import distribute from .mapper.flattener import flatten +from .typing import NumberT, ScalarT, ArithmeticExpressionT, ExpressionT, BoolT __all__ = ( + "ArithmeticExpressionT", + "BoolT", + "Expression", + "ExpressionT", + "NumberT", "Polynomial", + "ScalarT", + "Variable", "compile", "compiler", "cse", @@ -78,6 +90,7 @@ "evaluate_kw", "evaluator", "expand", + "expr_dataclass", "flatten", "flattened_product", "flattened_sum", From 9a2f899246acac415294cc8a55b44957ed8df64f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Oct 2024 16:00:54 -0500 Subject: [PATCH 4/7] Introduce ArithmeticExpressionT type alias So that one can keep doing arithmetic with the result of arithmetic without the type checker complaining --- pymbolic/primitives.py | 51 +++++++++++++++++++++++++----------------- pymbolic/typing.py | 1 + 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 57948ab5..e2328610 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -42,7 +42,7 @@ from typing_extensions import TypeIs, dataclass_transform from . import traits -from .typing import ExpressionT, NumberT, ScalarT +from .typing import ArithmeticExpressionT, ExpressionT, NumberT, ScalarT if TYPE_CHECKING: @@ -324,8 +324,8 @@ def init_arg_names(self) -> tuple[str, ...]: # {{{ arithmetic - def __add__(self, other: object) -> ExpressionT: - if not is_valid_operand(other): + def __add__(self, other: object) -> ArithmeticExpressionT: + if not is_arithmetic_expression(other): return NotImplemented if is_nonzero(other): if self: @@ -338,8 +338,8 @@ def __add__(self, other: object) -> ExpressionT: else: return self - def __radd__(self, other: object) -> ExpressionT: - assert is_constant(other) + def __radd__(self, other: object) -> ArithmeticExpressionT: + assert is_number(other) if is_nonzero(other): if self: return Sum((other, self)) @@ -348,7 +348,7 @@ def __radd__(self, other: object) -> ExpressionT: else: return self - def __sub__(self, other: object) -> ExpressionT: + def __sub__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented @@ -357,7 +357,7 @@ def __sub__(self, other: object) -> ExpressionT: else: return self - def __rsub__(self, other: object) -> ExpressionT: + def __rsub__(self, other: object) -> ArithmeticExpressionT: if not is_constant(other): return NotImplemented @@ -366,7 +366,7 @@ def __rsub__(self, other: object) -> ExpressionT: else: return -self - def __mul__(self, other: object) -> ExpressionT: + def __mul__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented @@ -378,7 +378,7 @@ def __mul__(self, other: object) -> ExpressionT: else: return Product((self, other)) - def __rmul__(self, other: object) -> ExpressionT: + def __rmul__(self, other: object) -> ArithmeticExpressionT: if not is_constant(other): return NotImplemented @@ -389,7 +389,7 @@ def __rmul__(self, other: object) -> ExpressionT: else: return Product((other, self)) - def __div__(self, other: object) -> ExpressionT: + def __div__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented @@ -399,7 +399,7 @@ def __div__(self, other: object) -> ExpressionT: return quotient(self, other) __truediv__ = __div__ - def __rdiv__(self, other: object) -> ExpressionT: + def __rdiv__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented @@ -408,7 +408,7 @@ def __rdiv__(self, other: object) -> ExpressionT: return quotient(other, self) __rtruediv__ = __rdiv__ - def __floordiv__(self, other: object) -> ExpressionT: + def __floordiv__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented @@ -417,15 +417,15 @@ def __floordiv__(self, other: object) -> ExpressionT: return self return FloorDiv(self, other) - def __rfloordiv__(self, other: object) -> ExpressionT: - if not is_valid_operand(other): + def __rfloordiv__(self, other: object) -> ArithmeticExpressionT: + if not is_arithmetic_expression(other): return NotImplemented if is_zero(self-1): return other return FloorDiv(other, self) - def __mod__(self, other: object) -> ExpressionT: + def __mod__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented @@ -434,13 +434,13 @@ def __mod__(self, other: object) -> ExpressionT: return 0 return Remainder(self, other) - def __rmod__(self, other: object) -> ExpressionT: + def __rmod__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented return Remainder(other, self) - def __pow__(self, other: object) -> ExpressionT: + def __pow__(self, other: object) -> ArithmeticExpressionT: if not is_valid_operand(other): return NotImplemented @@ -451,7 +451,7 @@ def __pow__(self, other: object) -> ExpressionT: return self return Power(self, other) - def __rpow__(self, other: object) -> ExpressionT: + def __rpow__(self, other: object) -> ArithmeticExpressionT: assert is_constant(other) if is_zero(other): # base zero @@ -535,10 +535,10 @@ def __rand__(self, other: object) -> BitwiseAnd: # {{{ misc - def __neg__(self) -> ExpressionT: + def __neg__(self) -> ArithmeticExpressionT: return -1*self - def __pos__(self) -> ExpressionT: + def __pos__(self) -> ArithmeticExpressionT: return self def __call__(self, *args, **kwargs) -> Call | CallWithKwargs: @@ -1755,11 +1755,13 @@ def quotient(numerator, denominator): global VALID_CONSTANT_CLASSES global VALID_OPERANDS VALID_CONSTANT_CLASSES: tuple[type, ...] = (int, float, complex) +_BOOL_CLASSES: tuple[type, ...] = (bool,) VALID_OPERANDS = (Expression,) try: import numpy VALID_CONSTANT_CLASSES += (numpy.number, numpy.bool_) + _BOOL_CLASSES += (numpy.bool_, ) except ImportError: pass @@ -1768,10 +1770,19 @@ def is_constant(value: object) -> TypeIs[ScalarT]: return isinstance(value, VALID_CONSTANT_CLASSES) +def is_number(value: object) -> TypeIs[NumberT]: + return (not isinstance(value, _BOOL_CLASSES) + and isinstance(value, VALID_CONSTANT_CLASSES)) + + def is_valid_operand(value: object) -> TypeIs[ExpressionT]: return isinstance(value, VALID_OPERANDS) or is_constant(value) +def is_arithmetic_expression(value: object) -> TypeIs[ArithmeticExpressionT]: + return not isinstance(value, _BOOL_CLASSES) and is_valid_operand(value) + + def register_constant_class(class_): global VALID_CONSTANT_CLASSES diff --git a/pymbolic/typing.py b/pymbolic/typing.py index 46a40d4a..704d18e3 100644 --- a/pymbolic/typing.py +++ b/pymbolic/typing.py @@ -57,6 +57,7 @@ _ScalarOrExpression = Union[ScalarT, "Expression"] +ArithmeticExpressionT: TypeAlias = Union[NumberT, "Expression"] ExpressionT: TypeAlias = Union[_ScalarOrExpression, Tuple[_ScalarOrExpression, ...]] From 0ed2208d590bf180b5c2a930817167033504783f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Oct 2024 16:06:44 -0500 Subject: [PATCH 5/7] Doc improvements --- doc/conf.py | 6 + doc/primitives.rst | 1 + pymbolic/mapper/__init__.py | 8 +- pymbolic/mapper/stringifier.py | 6 +- pymbolic/primitives.py | 259 +++++++++++++++++++-------------- pymbolic/typing.py | 21 +++ 6 files changed, 186 insertions(+), 115 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 01231b13..5d2018a3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -24,4 +24,10 @@ "sympy": ("https://docs.sympy.org/dev/", None), "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest/", None), + "immutabledict": + ("https://immutabledict.corenting.fr/", None) +} +autodoc_type_aliases = { + "ExpressionT": "ExpressionT", + "ArithmeticExpressionT": "ArithmeticExpressionT", } diff --git a/doc/primitives.rst b/doc/primitives.rst index aa6eb627..cddca023 100644 --- a/doc/primitives.rst +++ b/doc/primitives.rst @@ -1,6 +1,7 @@ Primitives (Basic Objects) ========================== +.. automodule:: pymbolic.typing .. automodule:: pymbolic.primitives .. vim: sw=4 diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 84c166b3..8ad8c92d 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -49,7 +49,7 @@ .. rubric:: Handling objects that don't declare mapper methods In particular, this includes many non-subclasses of - :class:`pymbolic.primitives.Expression`. + :class:`pymbolic.Expression`. .. automethod:: map_foreign @@ -113,16 +113,16 @@ class UnsupportedExpressionError(ValueError): # {{{ mapper base class Mapper: - """A visitor for trees of :class:`pymbolic.primitives.Expression` + """A visitor for trees of :class:`pymbolic.Expression` subclasses. Each expression-derived object is dispatched to the - method named by the :attr:`pymbolic.primitives.Expression.mapper_method` + method named by the :attr:`pymbolic.Expression.mapper_method` attribute and if not found, the methods named by the class attribute *mapper_method* in the method resolution order of the object. """ def handle_unsupported_expression(self, expr, *args, **kwargs): """Mapper method that is invoked for - :class:`pymbolic.primitives.Expression` subclasses for which a mapper + :class:`pymbolic.Expression` subclasses for which a mapper method does not exist in this mapper. """ diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index 2310a49b..47e062c9 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -84,11 +84,11 @@ class StringifyMapper(Mapper): """A mapper to turn an expression tree into a string. - :class:`pymbolic.primitives.Expression.__str__` is often implemented using + :class:`pymbolic.Expression.__str__` is often implemented using this mapper. - When it encounters an unsupported :class:`pymbolic.primitives.Expression` - subclass, it calls its :meth:`pymbolic.primitives.Expression.make_stringifier` + When it encounters an unsupported :class:`pymbolic.Expression` + subclass, it calls its :meth:`pymbolic.Expression.make_stringifier` method to get a :class:`StringifyMapper` that potentially does. """ diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index e2328610..6a6bd4da 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -53,13 +53,9 @@ Expression base class --------------------- -.. autoclass:: Expression - -.. class:: ExpressionT +.. currentmodule:: pymbolic - A type that can be used in type annotations whenever an expression - is desired. A :class:`typing.Union` of :class:`Expression` and - built-in scalar types. +.. autoclass:: Expression .. autofunction:: expr_dataclass @@ -70,6 +66,8 @@ :undoc-members: :members: mapper_method +.. currentmodule:: pymbolic.primitives + .. autoclass:: Call :undoc-members: :members: mapper_method @@ -163,6 +161,11 @@ :undoc-members: :members: mapper_method +Slices +---------- + +.. autoclass:: Slice + Code generation helpers ----------------------- @@ -173,6 +176,14 @@ .. autoclass:: cse_scope .. autofunction:: make_common_subexpression +Symbolic derivatives and substitution +------------------------------------- + +Inspired by similar functionality in :mod:`sympy`. + +.. autoclass:: Substitution +.. autoclass:: Derivative + Helper functions ---------------- @@ -201,8 +212,14 @@ :undoc-members: :members: mapper_method -Outside references ------------------- +Helper classes +-------------- + +.. autoclass:: EmptyOK +.. autoclass:: _AttributeLookupCreator + +References +---------- .. class:: DataclassInstance @@ -223,6 +240,36 @@ .. class:: TypeIs See :data:`typing_extensions.TypeIs`. + +.. class:: Variable + + See :class:`pymbolic.Variable`. + +.. class:: ExpressionT + + See :class:`pymbolic.ExpressionT`. + +.. currentmodule:: pymbolic + +.. class:: Comparison + + See :class:`pymbolic.primitives.Comparison`. + +.. class:: LogicalNot + + See :class:`pymbolic.primitives.LogicalNot`. + +.. class:: LogicalAnd + + See :class:`pymbolic.primitives.LogicalAnd`. + +.. class:: LogicalOr + + See :class:`pymbolic.primitives.LogicalOr`. + +.. class:: Lookup + + See :class:`pymbolic.primitives.Lookup`. """ @@ -248,7 +295,11 @@ def __get__(self, owner_self, owner_cls): class _AttributeLookupCreator: - def __init__(self, aggregate: ExpressionT): + """Helper used by :attr:`pymbolic.Expression.a` to create lookups. + + .. automethod:: __getattr__ + """ + def __init__(self, aggregate: ExpressionT) -> None: self.aggregate = aggregate def __getattr__(self, name: str) -> Lookup: @@ -262,7 +313,8 @@ class EmptyOK: class Expression: """Superclass for parts of a mathematical expression. Overrides operators - to implicitly construct :class:`Sum`, :class:`Product` and other expressions. + to implicitly construct :class:`~pymbolic.primitives.Sum`, + :class:`~pymbolic.primitives.Product` and other expressions. Expression objects are immutable. @@ -271,16 +323,16 @@ class Expression: `PEP 634 `__-style pattern matching is now supported when Pymbolic is used under Python 3.10. - .. attribute:: a + .. autoproperty:: a - .. attribute:: attr + .. automethod:: attr - .. attribute:: mapper_method + .. autoattribute:: mapper_method The :class:`pymbolic.mapper.Mapper` method called for objects of this type. - .. method:: __getitem__ + .. automethod:: __getitem__ .. automethod:: make_stringifier @@ -1032,7 +1084,7 @@ class Leaf(AlgebraicLeaf): @expr_dataclass() class Variable(Leaf): """ - .. attribute:: name + .. autoattribute:: name """ name: str @@ -1074,44 +1126,40 @@ class FunctionSymbol(AlgebraicLeaf): class Call(AlgebraicLeaf): """A function invocation. - .. attribute:: function - - A :class:`Expression` that evaluates to a function. - - .. attribute:: parameters - - A :class:`tuple` of positional parameters, each element - of which is a :class:`Expression` or a constant. - + .. autoattribute:: function + .. autoattribute:: parameters """ function: ExpressionT + """A :class:`Expression` that evaluates to a function.""" + parameters: tuple[ExpressionT, ...] + """ + A :class:`tuple` of positional parameters, each element + of which is a :class:`Expression` or a constant. + """ @expr_dataclass() class CallWithKwargs(AlgebraicLeaf): """A function invocation with keyword arguments. - .. attribute:: function - - A :class:`Expression` that evaluates to a function. - - .. attribute:: parameters - - A :class:`tuple` of positional parameters, each element - of which is a :class:`Expression` or a constant. - - .. attribute:: kw_parameters - - A dictionary mapping names to arguments, , each - of which is a :class:`Expression` or a constant, - or an equivalent value accepted by the :class:`dict` - constructor. + .. autoattribute:: function + .. autoattribute:: parameters + .. autoattribute:: kw_parameters """ function: ExpressionT + """An :class:`Expression` that evaluates to a function.""" + parameters: tuple[ExpressionT, ...] + """A :class:`tuple` of positional parameters, each element + of which is a :class:`Expression` or a constant. + """ + kw_parameters: Mapping[str, ExpressionT] + """A dictionary mapping names to arguments, each + of which is a :class:`Expression` or a constant. + """ def __post_init__(self): try: @@ -1119,7 +1167,8 @@ def __post_init__(self): except Exception: warn("CallWithKwargs created with non-hashable kw_parameters. " "This is deprecated and will stop working in 2025. " - "If you need an immutable mapping, try the immutabledcit package.", + "If you need an immutable mapping, " + "try the :mod:`immutabledict` package.", DeprecationWarning, stacklevel=3 ) object.__setattr__(self, "kw_parameters", immutabledict(self.kw_parameters)) @@ -1128,13 +1177,6 @@ def __post_init__(self): @expr_dataclass() class Subscript(AlgebraicLeaf): """An array subscript. - - .. attribute:: aggregate - .. attribute:: index - .. attribute:: index_tuple - - Return :attr:`index` wrapped in a single-element tuple, if it is not already - a tuple. """ aggregate: ExpressionT @@ -1142,6 +1184,11 @@ class Subscript(AlgebraicLeaf): @property def index_tuple(self) -> tuple[ExpressionT, ...]: + """ + Return :attr:`index` wrapped in a single-element tuple, if it is not already + a tuple. + """ + if isinstance(self.index, tuple): return self.index else: @@ -1165,9 +1212,7 @@ class Lookup(AlgebraicLeaf): @expr_dataclass() class Sum(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1215,9 +1260,7 @@ def __bool__(self): @expr_dataclass() class Product(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1275,32 +1318,32 @@ def __bool__(self): @expr_dataclass() class Quotient(QuotientBase): """ - .. attribute:: numerator - .. attribute:: denominator + .. autoattribute:: numerator + .. autoattribute:: denominator """ @expr_dataclass() class FloorDiv(QuotientBase): """ - .. attribute:: numerator - .. attribute:: denominator + .. autoattribute:: numerator + .. autoattribute:: denominator """ @expr_dataclass() class Remainder(QuotientBase): """ - .. attribute:: numerator - .. attribute:: denominator + .. autoattribute:: numerator + .. autoattribute:: denominator """ @expr_dataclass() class Power(Expression): """ - .. attribute:: base - .. attribute:: exponent + .. autoattribute:: base + .. autoattribute:: exponent """ base: ExpressionT @@ -1320,16 +1363,16 @@ class _ShiftOperator(Expression): @expr_dataclass() class LeftShift(_ShiftOperator): """ - .. attribute:: shiftee - .. attribute:: shift + .. autoattribute:: shiftee + .. autoattribute:: shift """ @expr_dataclass() class RightShift(_ShiftOperator): """ - .. attribute:: shiftee - .. attribute:: shift + .. autoattribute:: shiftee + .. autoattribute:: shift """ # }}} @@ -1340,7 +1383,7 @@ class RightShift(_ShiftOperator): @expr_dataclass() class BitwiseNot(Expression): """ - .. attribute:: child + .. autoattribute:: child """ child: ExpressionT @@ -1349,9 +1392,7 @@ class BitwiseNot(Expression): @expr_dataclass() class BitwiseOr(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1360,9 +1401,7 @@ class BitwiseOr(Expression): @expr_dataclass() class BitwiseXor(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1370,9 +1409,7 @@ class BitwiseXor(Expression): @expr_dataclass() class BitwiseAnd(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1384,24 +1421,24 @@ class BitwiseAnd(Expression): @expr_dataclass() class Comparison(Expression): """ - .. attribute:: left - .. attribute:: operator - - One of ``[">", ">=", "==", "!=", "<", "<="]``. - - .. attribute:: right + .. autoattribute:: left + .. autoattribute:: operator + .. autoattribute:: right .. note:: Unlike other expressions, comparisons are not implicitly constructed by - comparing :class:`Expression` objects. See :meth:`Expression.eq`. + comparing :class:`Expression` objects. See :meth:`pymbolic.Expression.eq`. - .. attribute:: operator_to_name - .. attribute:: name_to_operator + .. autoattribute:: operator_to_name + .. autoattribute:: name_to_operator """ left: ExpressionT + operator: str + """One of ``[">", ">=", "==", "!=", "<", "<="]``.""" + right: ExpressionT operator_to_name: ClassVar[dict[str, str]] = { @@ -1435,7 +1472,7 @@ def __post_init__(self): @expr_dataclass() class LogicalNot(Expression): """ - .. attribute:: child + .. autoattribute:: child """ child: ExpressionT @@ -1444,9 +1481,7 @@ class LogicalNot(Expression): @expr_dataclass() class LogicalOr(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1455,9 +1490,7 @@ class LogicalOr(Expression): @expr_dataclass() class LogicalAnd(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1465,9 +1498,9 @@ class LogicalAnd(Expression): @expr_dataclass() class If(Expression): """ - .. attribute:: condition - .. attribute:: then - .. attribute:: else_ + .. autoattribute:: condition + .. autoattribute:: then + .. autoattribute:: else_ """ condition: ExpressionT @@ -1478,9 +1511,7 @@ class If(Expression): @expr_dataclass() class Min(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1488,9 +1519,7 @@ class Min(Expression): @expr_dataclass() class Max(Expression): """ - .. attribute:: children - - A :class:`tuple`. + .. autoattribute:: children """ children: tuple[ExpressionT, ...] @@ -1529,11 +1558,10 @@ class CommonSubexpression(Expression): should only be evaluated once. If, in code generation, it is assigned to a variable, a name starting with :attr:`prefix` should be used. - .. attribute:: child - .. attribute:: prefix - .. attribute:: scope + .. autoattribute:: child + .. autoattribute:: prefix + .. autoattribute:: scope - One of the values in :class:`cse_scope`. See there for meaning. See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example. """ @@ -1541,6 +1569,9 @@ class CommonSubexpression(Expression): child: ExpressionT prefix: str | None = None scope: str = cse_scope.EVALUATION + """ + One of the values in :class:`cse_scope`. See there for meaning. + """ def __post_init__(self): if self.scope is None: @@ -1563,7 +1594,12 @@ def get_extra_properties(self): @expr_dataclass() class Substitution(Expression): - """Work-alike of sympy's Subs.""" + """Work-alike of :class:`~sympy.core.function.Subs`. + + .. autoattribute:: child + .. autoattribute:: variables + .. autoattribute:: values + """ child: ExpressionT variables: tuple[str, ...] @@ -1572,7 +1608,11 @@ class Substitution(Expression): @expr_dataclass() class Derivative(Expression): - """Work-alike of sympy's Derivative.""" + """Work-alike of sympy's :class:`~sympy.core.function.Derivative`. + + .. autoattribute:: child + .. autoattribute:: variables + """ child: ExpressionT variables: tuple[str, ...] @@ -1580,7 +1620,10 @@ class Derivative(Expression): @expr_dataclass() class Slice(Expression): - """A slice expression as in a[1:7].""" + """A slice expression as in a[1:7]. + + .. autoattribute:: children + """ children: (tuple[()] | tuple[ExpressionT] @@ -1630,7 +1673,7 @@ class NaN(Expression): to which ``np.nan == np.nan`` is *False*, but ``(np.nan,) == (np.nan,)`` is True. - .. attribute:: data_type + .. autoattribute:: data_type The data type used for the actual realization of the constant. Defaults to *None*. If given, This must be a callable to which a NaN diff --git a/pymbolic/typing.py b/pymbolic/typing.py index 704d18e3..00f02a28 100644 --- a/pymbolic/typing.py +++ b/pymbolic/typing.py @@ -1,3 +1,21 @@ +""" +.. currentmodule:: pymbolic + +Typing helpers +-------------- + +.. autoclass:: BoolT +.. autoclass:: NumberT +.. autoclass:: ScalarT +.. autoclass:: ArithmeticExpressionT + + A narrower type alias than :class:`ExpressionT` that is returned by + arithmetic operators, to allow continue doing arithmetic with the result + of arithmetic. + +.. autoclass:: ExpressionT +""" + from __future__ import annotations from typing import TYPE_CHECKING, Tuple, TypeVar, Union @@ -55,9 +73,12 @@ NumberT: TypeAlias = Union[IntegerT, InexactNumberT] ScalarT: TypeAlias = Union[NumberT, BoolT] +# FIXME: This only allows nesting tuples one-deep. When attempting to fix this, there +# are complaints about recursive type aliases. _ScalarOrExpression = Union[ScalarT, "Expression"] ArithmeticExpressionT: TypeAlias = Union[NumberT, "Expression"] + ExpressionT: TypeAlias = Union[_ScalarOrExpression, Tuple[_ScalarOrExpression, ...]] From a178ca50dd26d58b68a045e0d6f4ed798c204ad4 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 2 Oct 2024 20:50:37 -0500 Subject: [PATCH 6/7] Fix typing of recursive tuples of expressions --- pymbolic/typing.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pymbolic/typing.py b/pymbolic/typing.py index 00f02a28..debe5f28 100644 --- a/pymbolic/typing.py +++ b/pymbolic/typing.py @@ -73,13 +73,10 @@ NumberT: TypeAlias = Union[IntegerT, InexactNumberT] ScalarT: TypeAlias = Union[NumberT, BoolT] -# FIXME: This only allows nesting tuples one-deep. When attempting to fix this, there -# are complaints about recursive type aliases. - _ScalarOrExpression = Union[ScalarT, "Expression"] ArithmeticExpressionT: TypeAlias = Union[NumberT, "Expression"] -ExpressionT: TypeAlias = Union[_ScalarOrExpression, Tuple[_ScalarOrExpression, ...]] +ExpressionT: TypeAlias = Union[_ScalarOrExpression, Tuple["ExpressionT", ...]] T = TypeVar("T") From 1f7ee4e7640474c45b11b747fdd14bfdbd6fb3a5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 3 Oct 2024 08:06:44 -0500 Subject: [PATCH 7/7] @expr_dataclass: Allow users to provide their own __hash__ --- pymbolic/primitives.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 6a6bd4da..9ebe5575 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -878,6 +878,7 @@ def __iter__(self): def _augment_expression_dataclass( cls: type[DataclassInstance], + hash: bool, ) -> None: attr_tuple = ", ".join(f"self.{fld.name}" for fld in fields(cls)) if attr_tuple: @@ -946,7 +947,8 @@ def {cls.__name__}_hash(self): object.__setattr__(self, "_hash_value", hash_val) return hash_val - cls.__hash__ = {cls.__name__}_hash + if {hash}: + cls.__hash__ = {cls.__name__}_hash def {cls.__name__}_init_arg_names(self): @@ -1036,7 +1038,10 @@ def {cls.__name__}_setstate(self, state): @dataclass_transform(frozen_default=True) -def expr_dataclass(init: bool = True) -> Callable[[type[_T]], type[_T]]: +def expr_dataclass( + init: bool = True, + hash: bool = True + ) -> Callable[[type[_T]], type[_T]]: """A class decorator that makes the class a :func:`~dataclasses.dataclass` while also adding functionality needed for :class:`Expression` nodes. Specifically, it adds cached hashing, equality comparisons @@ -1052,12 +1057,15 @@ def expr_dataclass(init: bool = True) -> Callable[[type[_T]], type[_T]]: def map_cls(cls: type[_T]) -> type[_T]: # Frozen dataclasses (empirically) have a ~20% speed penalty in pymbolic, # and their frozen-ness is arguably a debug feature. - dc_cls = dataclass(init=init, frozen=__debug__, repr=False)(cls) + + # We provide __eq__/__hash__ below, don't redundantly generate it. + dc_cls = dataclass(init=init, eq=False, frozen=__debug__, repr=False)(cls) # FIXME: I'm not sure how to tell mypy that dc_cls is type[DataclassInstance] # It should just understand that? _augment_expression_dataclass( dc_cls, # type: ignore[arg-type] + hash=hash ) return dc_cls