From 31dbd3933a39dd0ac8302746262f43a98aef286f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 20 Nov 2024 16:06:30 -0600 Subject: [PATCH] Rename type aliases for consistency --- doc/conf.py | 12 +- doc/index.rst | 5 +- doc/utilities.rst | 3 +- pymbolic/__init__.py | 85 +++++---- pymbolic/cse.py | 2 +- pymbolic/geometric_algebra/__init__.py | 6 +- pymbolic/geometric_algebra/mapper.py | 10 +- pymbolic/geometric_algebra/primitives.py | 8 +- pymbolic/interop/ast.py | 6 +- pymbolic/interop/matchpy/__init__.py | 30 +-- pymbolic/interop/matchpy/tofrom.py | 5 +- pymbolic/mapper/__init__.py | 139 +++++++------- pymbolic/mapper/coefficient.py | 10 +- pymbolic/mapper/collector.py | 26 +-- pymbolic/mapper/constant_folder.py | 24 +-- pymbolic/mapper/distributor.py | 12 +- pymbolic/mapper/evaluator.py | 10 +- pymbolic/mapper/flattener.py | 24 ++- pymbolic/mapper/stringifier.py | 22 +-- pymbolic/mapper/substitutor.py | 23 +-- pymbolic/parser.py | 4 +- pymbolic/primitives.py | 222 ++++++++++++----------- pymbolic/rational.py | 10 +- pymbolic/typing.py | 94 +++++++--- pyproject.toml | 2 +- test/test_pymbolic.py | 10 +- test/testlib.py | 4 +- 27 files changed, 448 insertions(+), 360 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index a3c1b0af..039bab0f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -28,10 +28,18 @@ ("https://immutabledict.corenting.fr/", None) } autodoc_type_aliases = { - "ExpressionT": "ExpressionT", - "ArithmeticExpressionT": "ArithmeticExpressionT", + "Expression": "Expression", + "ArithmeticExpression": "ArithmeticExpression", } + +nitpick_ignore_regex = [ + # Avoids this error. Not sure where to even look. + # :1: WARNING: py:class reference target not found: ExpressionNode [ref.class] # noqa: E501 + ["py:class", r"ExpressionNode"], + ] + + import sys diff --git a/doc/index.rst b/doc/index.rst index 328ab814..43e0c765 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -69,10 +69,11 @@ You can also easily define your own objects to use inside an expression: .. doctest:: - >>> from pymbolic.primitives import Expression, expr_dataclass + >>> from pymbolic import ExpressionNode, expr_dataclass + >>> from pymbolic.typing import Expression >>> >>> @expr_dataclass() - ... class FancyOperator(Expression): + ... class FancyOperator(ExpressionNode): ... operand: Expression ... >>> u diff --git a/doc/utilities.rst b/doc/utilities.rst index 466795ba..27d9cacc 100644 --- a/doc/utilities.rst +++ b/doc/utilities.rst @@ -8,7 +8,8 @@ Parser .. function:: parse(expr_str) - Return a :class:`pymbolic.primitives.Expression` tree corresponding to *expr_str*. + Return a :class:`pymbolic.primitives.ExpressionNode` tree corresponding + to *expr_str*. The parser is also relatively easy to extend. See the source code of the following class. diff --git a/pymbolic/__init__.py b/pymbolic/__init__.py index 5ed3cb35..ca17dcaf 100644 --- a/pymbolic/__init__.py +++ b/pymbolic/__init__.py @@ -24,54 +24,60 @@ """ -from pymbolic.version import VERSION_TEXT as __version__ # noqa - -from . import parser -from . import compiler +from functools import partial -from .mapper import evaluator -from .mapper import stringifier -from .mapper import dependency -from .mapper import substitutor -from .mapper import differentiator -from .mapper import distributor -from .mapper import flattener -from . import primitives +from pytools import module_getattr_for_deprecations -from .primitives import (Variable as var, # noqa: N813 +from . import compiler, parser, primitives +from .compiler import compile +from .mapper import ( + dependency, + differentiator, + distributor, + evaluator, + flattener, + stringifier, + substitutor, +) +from .mapper.differentiator import differentiate, differentiate as diff +from .mapper.distributor import distribute, distribute as expand +from .mapper.evaluator import evaluate, evaluate_kw +from .mapper.flattener import flatten +from .mapper.substitutor import substitute +from .parser import parse +from .primitives import ( # noqa: N813 + ExpressionNode, Variable, - Expression, - variables, - flattened_sum, - subscript, + Variable as var, + disable_subscript_by_getitem, + expr_dataclass, flattened_product, - quotient, + flattened_sum, linear_combination, make_common_subexpression as cse, make_sym_vector, - disable_subscript_by_getitem, - expr_dataclass, + quotient, + subscript, + variables, ) -from .parser import parse -from .mapper.evaluator import evaluate -from .mapper.evaluator import evaluate_kw -from .compiler import compile -from .mapper.substitutor import substitute -from .mapper.differentiator import differentiate as diff -from .mapper.differentiator import differentiate -from .mapper.distributor import distribute as expand -from .mapper.distributor import distribute -from .mapper.flattener import flatten -from .typing import NumberT, ScalarT, ArithmeticExpressionT, ExpressionT, BoolT +from .typing import ( + ArithmeticExpression, + Bool, + Expression, + Expression as _TypingExpression, + Number, + Scalar, +) +from pymbolic.version import VERSION_TEXT as __version__ # noqa __all__ = ( - "ArithmeticExpressionT", - "BoolT", + "ArithmeticExpression", + "Bool", "Expression", - "ExpressionT", - "NumberT", - "ScalarT", + "ExpressionNode", + "Number", + "Scalar", "Variable", "compile", "compiler", @@ -105,3 +111,10 @@ "var", "variables", ) + +__getattr__ = partial(module_getattr_for_deprecations, __name__, { + "ExpressionT": ("pymbolic.typing.Expression", _TypingExpression, 2026), + "ArithmeticExpressionT": ("ArithmeticExpression", ArithmeticExpression, 2026), + "BoolT": ("Bool", Bool, 2026), + "ScalarT": ("Scalar", Scalar, 2026), + }) diff --git a/pymbolic/cse.py b/pymbolic/cse.py index bd1ac554..0d374dae 100644 --- a/pymbolic/cse.py +++ b/pymbolic/cse.py @@ -137,7 +137,7 @@ def tag_common_subexpressions(exprs): get_key = NormalizedKeyGetter() ucm = UseCountMapper(get_key) - if isinstance(exprs, prim.Expression): + if isinstance(exprs, prim.ExpressionNode): raise TypeError("exprs should be an iterable of expressions") for expr in exprs: diff --git a/pymbolic/geometric_algebra/__init__.py b/pymbolic/geometric_algebra/__init__.py index 605ef1e3..86bbd03e 100644 --- a/pymbolic/geometric_algebra/__init__.py +++ b/pymbolic/geometric_algebra/__init__.py @@ -33,7 +33,7 @@ from pytools import memoize, memoize_method from pymbolic.primitives import expr_dataclass, is_zero -from pymbolic.typing import ArithmeticExpressionT, T +from pymbolic.typing import ArithmeticExpression, T __doc__ = """ @@ -293,7 +293,7 @@ def get_euclidean_space(n: int) -> Space: # }}} -CoeffT = TypeVar("CoeffT", bound=ArithmeticExpressionT) +CoeffT = TypeVar("CoeffT", bound=ArithmeticExpression) # {{{ blade product weights @@ -428,7 +428,7 @@ def _cast_to_mv(obj: Any, space: Space) -> MultiVector: class MultiVector(Generic[CoeffT]): r"""An immutable multivector type. Its implementation follows [DFM]. It is pickleable, and not picky about what data is used as coefficients. - It supports :class:`pymbolic.primitives.Expression` objects of course, + It supports :class:`pymbolic.primitives.ExpressionNode` objects of course, but it can take just about any other scalar-ish coefficients. .. autoattribute:: data diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index d2508f9a..085a3ca9 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -49,21 +49,23 @@ PREC_NONE, StringifyMapper as StringifyMapperBase, ) -from pymbolic.primitives import Expression +from pymbolic.primitives import ExpressionNode class IdentityMapper(IdentityMapperBase[P]): def map_nabla( - self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> Expression: + self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionNode: return expr def map_nabla_component(self, - expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs) -> Expression: + expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionNode: return expr def map_derivative_source(self, expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs - ) -> Expression: + ) -> ExpressionNode: operand = self.rec(expr.operand, *args, **kwargs) if operand is expr.operand: return expr diff --git a/pymbolic/geometric_algebra/primitives.py b/pymbolic/geometric_algebra/primitives.py index 00cf299c..47de49b1 100644 --- a/pymbolic/geometric_algebra/primitives.py +++ b/pymbolic/geometric_algebra/primitives.py @@ -29,8 +29,8 @@ from collections.abc import Hashable from typing import ClassVar -from pymbolic.primitives import Expression, Variable, expr_dataclass -from pymbolic.typing import ExpressionT +from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass +from pymbolic.typing import Expression class MultiVectorVariable(Variable): @@ -39,7 +39,7 @@ class MultiVectorVariable(Variable): # {{{ geometric calculus -class _GeometricCalculusExpression(Expression): +class _GeometricCalculusExpression(ExpressionNode): def stringifier(self): from pymbolic.geometric_algebra.mapper import StringifyMapper return StringifyMapper @@ -58,7 +58,7 @@ class Nabla(_GeometricCalculusExpression): @expr_dataclass() class DerivativeSource(_GeometricCalculusExpression): - operand: ExpressionT + operand: Expression nabla_id: Hashable diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py index 26df36bd..db4201e4 100644 --- a/pymbolic/interop/ast.py +++ b/pymbolic/interop/ast.py @@ -31,7 +31,7 @@ import pymbolic.primitives as p from pymbolic.mapper import CachedMapper -from pymbolic.typing import ExpressionT +from pymbolic.typing import Expression __doc__ = r''' @@ -263,7 +263,7 @@ def map_variable(self, expr) -> ast.expr: return ast.Name(id=expr.name) def _map_multi_children_op(self, - children: tuple[ExpressionT, ...], + children: tuple[Expression, ...], op_type: ast.operator) -> ast.expr: rec_children = [self.rec(child) for child in children] result = rec_children[-1] @@ -435,7 +435,7 @@ def to_python_ast(expr) -> ast.expr: return PymbolicToASTMapper()(expr) -def to_evaluatable_python_function(expr: ExpressionT, +def to_evaluatable_python_function(expr: Expression, fn_name: str ) -> str: """ diff --git a/pymbolic/interop/matchpy/__init__.py b/pymbolic/interop/matchpy/__init__.py index a7a9e912..0e1a3965 100644 --- a/pymbolic/interop/matchpy/__init__.py +++ b/pymbolic/interop/matchpy/__init__.py @@ -57,13 +57,13 @@ ) import pymbolic.primitives as p -from pymbolic.typing import ScalarT +from pymbolic.typing import Scalar as PbScalar ExprT: TypeAlias = Expression ConstantT = TypeVar("ConstantT") -ToMatchpyT = Callable[[p.Expression], ExprT] -FromMatchpyT = Callable[[ExprT], p.Expression] +ToMatchpyT = Callable[[p.ExpressionNode], ExprT] +FromMatchpyT = Callable[[ExprT], p.ExpressionNode] _NOT_OPERAND_METADATA = {"not_an_operand": True} @@ -95,7 +95,7 @@ def __lt__(self, other): @op_dataclass -class Scalar(_Constant[ScalarT]): +class Scalar(_Constant[PbScalar]): _mapper_method: str = "map_scalar" @@ -360,11 +360,11 @@ def _get_operand_at_path(expr: PymbolicOp, path: tuple[int, ...]) -> PymbolicOp: return result -def match(subject: p.Expression, - pattern: p.Expression, +def match(subject: p.ExpressionNode, + pattern: p.ExpressionNode, to_matchpy_expr: ToMatchpyT | None = None, from_matchpy_expr: FromMatchpyT | None = None - ) -> Iterator[Mapping[str, p.Expression | ScalarT]]: + ) -> Iterator[Mapping[str, p.ExpressionNode | PbScalar]]: from matchpy import Pattern, match from .tofrom import FromMatchpyExpressionMapper, ToMatchpyExpressionMapper @@ -383,12 +383,12 @@ def match(subject: p.Expression, for name, expr in subst.items()} -def match_anywhere(subject: p.Expression, - pattern: p.Expression, +def match_anywhere(subject: p.ExpressionNode, + pattern: p.ExpressionNode, to_matchpy_expr: ToMatchpyT | None = None, from_matchpy_expr: FromMatchpyT | None = None - ) -> Iterator[tuple[Mapping[str, p.Expression | ScalarT], - p.Expression | ScalarT] + ) -> Iterator[tuple[Mapping[str, p.ExpressionNode | PbScalar], + p.ExpressionNode | PbScalar] ]: from matchpy import Pattern, match_anywhere @@ -409,8 +409,8 @@ def match_anywhere(subject: p.Expression, from_matchpy_expr(_get_operand_at_path(m_subject, path))) -def make_replacement_rule(pattern: p.Expression, - replacement: Callable[..., p.Expression], +def make_replacement_rule(pattern: p.ExpressionNode, + replacement: Callable[..., p.ExpressionNode], to_matchpy_expr: ToMatchpyT | None = None, from_matchpy_expr: FromMatchpyT | None = None ) -> ReplacementRule: @@ -437,11 +437,11 @@ def make_replacement_rule(pattern: p.Expression, from_matchpy_expr)) -def replace_all(expression: p.Expression, +def replace_all(expression: p.ExpressionNode, rules: Iterable[ReplacementRule], to_matchpy_expr: ToMatchpyT | None = None, from_matchpy_expr: FromMatchpyT | None = None - ) -> p.Expression | tuple[p.Expression, ...]: + ) -> p.ExpressionNode | tuple[p.ExpressionNode, ...]: import collections.abc as abc from matchpy import replace_all diff --git a/pymbolic/interop/matchpy/tofrom.py b/pymbolic/interop/matchpy/tofrom.py index 63996881..db5e564a 100644 --- a/pymbolic/interop/matchpy/tofrom.py +++ b/pymbolic/interop/matchpy/tofrom.py @@ -12,6 +12,7 @@ import pymbolic.primitives as p from pymbolic.interop.matchpy.mapper import Mapper as BaseMatchPyMapper from pymbolic.mapper import Mapper as BasePymMapper +from pymbolic.typing import Scalar as PbScalar # {{{ to matchpy @@ -117,7 +118,7 @@ def map_star_wildcard(self, expr: p.StarWildcard) -> m.Wildcard: # {{{ from matchpy class FromMatchpyExpressionMapper(BaseMatchPyMapper): - def map_scalar(self, expr: m.Scalar) -> m.ScalarT: + def map_scalar(self, expr: m.Scalar) -> PbScalar: return expr.value def map_variable(self, expr: m.Variable) -> p.Variable: @@ -200,7 +201,7 @@ def map_if(self, expr: m.If) -> p.If: @dataclass(frozen=True, eq=True) class ToFromReplacement: - f: Callable[..., p.Expression] + f: Callable[..., p.ExpressionNode] to_matchpy_expr: m.ToMatchpyT from_matchpy_expr: m.FromMatchpyT diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 25a4a8cb..b1fc6ee3 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -39,7 +39,7 @@ from typing_extensions import ParamSpec, TypeIs import pymbolic.primitives as p -from pymbolic.typing import ArithmeticExpressionT, ExpressionT +from pymbolic.typing import ArithmeticExpression, Expression if TYPE_CHECKING: @@ -71,7 +71,7 @@ .. rubric:: Handling objects that don't declare mapper methods In particular, this includes many non-subclasses of - :class:`pymbolic.Expression`. + :class:`pymbolic.ExpressionNode`. .. automethod:: map_foreign @@ -149,9 +149,9 @@ class UnsupportedExpressionError(ValueError): class Mapper(Generic[ResultT, P]): - """A visitor for trees of :class:`pymbolic.Expression` + """A visitor for trees of :class:`pymbolic.ExpressionNode` subclasses. Each expression-derived object is dispatched to the - method named by the :attr:`pymbolic.Expression.mapper_method` + method named by the :attr:`pymbolic.ExpressionNode.mapper_method` attribute and if not found, the methods named by the class attribute *mapper_method* in the method resolution order of the object. @@ -163,7 +163,7 @@ class Mapper(Generic[ResultT, P]): def handle_unsupported_expression(self, expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Mapper method that is invoked for - :class:`pymbolic.Expression` subclasses for which a mapper + :class:`pymbolic.ExpressionNode` subclasses for which a mapper method does not exist in this mapper. """ @@ -172,7 +172,7 @@ def handle_unsupported_expression(self, type(self), type(expr))) def __call__(self, - expr: ExpressionT, *args: P.args, **kwargs: P.kwargs) -> ResultT: + expr: Expression, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Dispatch *expr* to its corresponding mapper method. Pass on ``*args`` and ``**kwargs`` unmodified. @@ -190,7 +190,7 @@ def __call__(self, result = method(expr, *args, **kwargs) return result - if isinstance(expr, p.Expression): + if isinstance(expr, p.ExpressionNode): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) if method_name: @@ -206,7 +206,7 @@ def __call__(self, def rec_fallback(self, expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: - if isinstance(expr, p.Expression): + if isinstance(expr, p.ExpressionNode): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) if method_name: @@ -289,11 +289,11 @@ def map_max(self, raise NotImplementedError def map_list(self, - expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs) -> ResultT: + expr: list[Expression], *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError def map_tuple(self, - expr: tuple[ExpressionT, ...], + expr: tuple[Expression, ...], *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError @@ -396,7 +396,7 @@ def __init__(self) -> None: Mapper.__init__(self) def get_cache_key(self, - expr: ExpressionT, + expr: Expression, *args: P.args, **kwargs: P.kwargs ) -> CacheKeyT: @@ -416,7 +416,7 @@ def get_cache_key(self, return (type(expr), expr, args, immutabledict(kwargs)) def __call__(self, - expr: ExpressionT, + expr: Expression, *args: P.args, **kwargs: P.kwargs ) -> ResultT: @@ -429,7 +429,7 @@ def __call__(self, method_name = getattr(expr, "mapper_method", None) if method_name is not None: method = cast( - Callable[Concatenate[ExpressionT, P], ResultT] | None, + Callable[Concatenate[Expression, P], ResultT] | None, getattr(self, method_name, None) ) if method is not None: @@ -592,12 +592,12 @@ def map_min(self, for child in expr.children) def map_tuple(self, - expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs + expr: tuple[Expression, ...], *args: P.args, **kwargs: P.kwargs ) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr) def map_list(self, - expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs + expr: list[Expression], *args: P.args, **kwargs: P.kwargs ) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr) @@ -607,7 +607,7 @@ def map_numpy_array(self, return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat) def map_multivector(self, - expr: MultiVector[ArithmeticExpressionT], + expr: MultiVector[ArithmeticExpression], *args: P.args, **kwargs: P.kwargs ) -> ResultT: return self.combine( @@ -688,7 +688,7 @@ class CachedCollector(CachedMapper, Collector): # {{{ identity mapper -class IdentityMapper(Mapper[ExpressionT, P]): +class IdentityMapper(Mapper[Expression, P]): """A :class:`Mapper` whose default mapper methods make a deep copy of each subexpression. @@ -699,48 +699,48 @@ class IdentityMapper(Mapper[ExpressionT, P]): """ def rec_arith(self, - expr: ArithmeticExpressionT, *args: P.args, **kwargs: P.kwargs - ) -> ArithmeticExpressionT: + expr: ArithmeticExpression, *args: P.args, **kwargs: P.kwargs + ) -> ArithmeticExpression: res = self.rec(expr, *args, **kwargs) assert p.is_arithmetic_expression(res) return res def map_constant(self, expr: object, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: # leaf -- no need to rebuild assert p.is_valid_operand(expr) return expr def map_variable(self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: # leaf -- no need to rebuild return expr def map_wildcard(self, expr: p.Wildcard, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: return expr def map_dot_wildcard(self, expr: p.DotWildcard, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: return expr def map_star_wildcard(self, expr: p.StarWildcard, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: return expr def map_function_symbol(self, expr: p.FunctionSymbol, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: return expr def map_call(self, expr: p.Call, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters @@ -755,12 +755,12 @@ def map_call(self, def map_call_with_kwargs(self, expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters ]) - kw_parameters: Mapping[str, ExpressionT] = immutabledict({ + kw_parameters: Mapping[str, Expression] = immutabledict({ key: self.rec(val, *args, **kwargs) for key, val in expr.kw_parameters.items()}) @@ -774,7 +774,7 @@ def map_call_with_kwargs(self, def map_subscript(self, expr: p.Subscript, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: aggregate = self.rec(expr.aggregate, *args, **kwargs) index = self.rec(expr.index, *args, **kwargs) if aggregate is expr.aggregate and index is expr.index: @@ -783,7 +783,7 @@ def map_subscript(self, def map_lookup(self, expr: p.Lookup, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: aggregate = self.rec(expr.aggregate, *args, **kwargs) if aggregate is expr.aggregate: return expr @@ -791,7 +791,7 @@ def map_lookup(self, def map_sum(self, expr: p.Sum, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children, strict=True)): @@ -801,7 +801,7 @@ def map_sum(self, def map_product(self, expr: p.Product, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children, strict=True)): @@ -811,7 +811,7 @@ def map_product(self, def map_quotient(self, expr: p.Quotient, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: numerator = self.rec_arith(expr.numerator, *args, **kwargs) denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: @@ -820,7 +820,7 @@ def map_quotient(self, def map_floor_div(self, expr: p.FloorDiv, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: numerator = self.rec_arith(expr.numerator, *args, **kwargs) denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: @@ -829,7 +829,7 @@ def map_floor_div(self, def map_remainder(self, expr: p.Remainder, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: numerator = self.rec_arith(expr.numerator, *args, **kwargs) denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: @@ -838,7 +838,7 @@ def map_remainder(self, def map_power(self, expr: p.Power, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: base = self.rec_arith(expr.base, *args, **kwargs) exponent = self.rec_arith(expr.exponent, *args, **kwargs) if base is expr.base and exponent is expr.exponent: @@ -847,7 +847,7 @@ def map_power(self, def map_left_shift(self, expr: p.LeftShift, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: shiftee = self.rec(expr.shiftee, *args, **kwargs) shift = self.rec(expr.shift, *args, **kwargs) if shiftee is expr.shiftee and shift is expr.shift: @@ -856,7 +856,7 @@ def map_left_shift(self, def map_right_shift(self, expr: p.RightShift, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: shiftee = self.rec(expr.shiftee, *args, **kwargs) shift = self.rec(expr.shift, *args, **kwargs) if shiftee is expr.shiftee and shift is expr.shift: @@ -865,7 +865,7 @@ def map_right_shift(self, def map_bitwise_not(self, expr: p.BitwiseNot, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr @@ -873,7 +873,7 @@ def map_bitwise_not(self, def map_bitwise_or(self, expr: p.BitwiseOr, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children, strict=True)): @@ -883,7 +883,7 @@ def map_bitwise_or(self, def map_bitwise_and(self, expr: p.BitwiseAnd, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children, strict=True)): @@ -893,7 +893,7 @@ def map_bitwise_and(self, def map_bitwise_xor(self, expr: p.BitwiseXor, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children, strict=True)): @@ -903,7 +903,7 @@ def map_bitwise_xor(self, def map_logical_not(self, expr: p.LogicalNot, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr @@ -911,7 +911,7 @@ def map_logical_not(self, def map_logical_or(self, expr: p.LogicalOr, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children, strict=True)): @@ -921,7 +921,7 @@ def map_logical_or(self, def map_logical_and(self, expr: p.LogicalAnd, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children, strict=True)): @@ -931,7 +931,7 @@ def map_logical_and(self, def map_comparison(self, expr: p.Comparison, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: left = self.rec(expr.left, *args, **kwargs) right = self.rec(expr.right, *args, **kwargs) if left is expr.left and right is expr.right: @@ -940,15 +940,15 @@ def map_comparison(self, return type(expr)(left, expr.operator, right) def map_list(self, - expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + expr: list[Expression], *args: P.args, **kwargs: P.kwargs + ) -> Expression: # True fact: lists aren't expressions return [self.rec(child, *args, **kwargs) for child in expr] # type: ignore[return-value] def map_tuple(self, - expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + expr: tuple[Expression, ...], *args: P.args, **kwargs: P.kwargs + ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr] if all(child is orig_child for child, orig_child in zip(children, expr, strict=True)): @@ -958,7 +958,7 @@ def map_tuple(self, def map_numpy_array(self, expr: np.ndarray, *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: import numpy result = numpy.empty(expr.shape, dtype=object) @@ -969,16 +969,16 @@ def map_numpy_array(self, return result # type: ignore[return-value] def map_multivector(self, - expr: MultiVector[ArithmeticExpressionT], + expr: MultiVector[ArithmeticExpression], *args: P.args, **kwargs: P.kwargs - ) -> ExpressionT: + ) -> Expression: # True fact: MultiVectors aren't expressions - return expr.map(lambda ch: cast(ArithmeticExpressionT, + return expr.map(lambda ch: cast(ArithmeticExpression, self.rec(ch, *args, **kwargs))) # type: ignore[return-value] def map_common_subexpression(self, expr: p.CommonSubexpression, - *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + *args: P.args, **kwargs: P.kwargs) -> Expression: result = self.rec(expr.child, *args, **kwargs) if result is expr.child: return expr @@ -991,7 +991,7 @@ def map_common_subexpression(self, def map_substitution(self, expr: p.Substitution, - *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + *args: P.args, **kwargs: P.kwargs) -> Expression: child = self.rec(expr.child, *args, **kwargs) values = tuple([self.rec(v, *args, **kwargs) for v in expr.values]) if child is expr.child and all(val is orig_val @@ -1002,7 +1002,7 @@ def map_substitution(self, def map_derivative(self, expr: p.Derivative, - *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + *args: P.args, **kwargs: P.kwargs) -> Expression: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr @@ -1011,7 +1011,7 @@ def map_derivative(self, def map_slice(self, expr: p.Slice, - *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + *args: P.args, **kwargs: P.kwargs) -> Expression: children: p.SliceChildrenT = cast(p.SliceChildrenT, tuple([ None if child is None else self.rec(child, *args, **kwargs) for child in expr.children @@ -1022,7 +1022,7 @@ def map_slice(self, return type(expr)(children) - def map_if(self, expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + def map_if(self, expr: p.If, *args: P.args, **kwargs: P.kwargs) -> Expression: condition = self.rec(expr.condition, *args, **kwargs) then = self.rec(expr.then, *args, **kwargs) else_ = self.rec(expr.else_, *args, **kwargs) @@ -1033,7 +1033,8 @@ def map_if(self, expr: p.If, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: return type(expr)(condition, then, else_) - def map_min(self, expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + def map_min(self, + expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> Expression: children = tuple([ self.rec(child, *args, **kwargs) for child in expr.children ]) @@ -1043,7 +1044,8 @@ def map_min(self, expr: p.Min, *args: P.args, **kwargs: P.kwargs) -> ExpressionT return type(expr)(children) - def map_max(self, expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + def map_max(self, + expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> Expression: children = tuple([ self.rec(child, *args, **kwargs) for child in expr.children ]) @@ -1053,12 +1055,13 @@ def map_max(self, expr: p.Max, *args: P.args, **kwargs: P.kwargs) -> ExpressionT return type(expr)(children) - def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> ExpressionT: + def map_nan(self, + expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> Expression: # Leaf node -- don't recurse return expr -class CachedIdentityMapper(CachedMapper[ExpressionT, P], IdentityMapper[P]): +class CachedIdentityMapper(CachedMapper[Expression, P], IdentityMapper[P]): pass # }}} @@ -1217,7 +1220,8 @@ def map_power(self, expr: p.Power, *args: P.args, **kwargs: P.kwargs) -> None: self.post_visit(expr, *args, **kwargs) def map_tuple(self, - expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs) -> None: + expr: tuple[Expression, ...], *args: P.args, **kwargs: P.kwargs + ) -> None: if not self.visit(expr, *args, **kwargs): return @@ -1238,7 +1242,7 @@ def map_numpy_array(self, self.post_visit(expr, *args, **kwargs) def map_multivector(self, - expr: MultiVector[ArithmeticExpressionT], + expr: MultiVector[ArithmeticExpression], *args: P.args, **kwargs: P.kwargs) -> None: if not self.visit(expr, *args, **kwargs): return @@ -1502,7 +1506,7 @@ class CSECachingMapperMixin(ABC, Generic[ResultT, P]): This method deliberately does not support extra arguments in mapper dispatch, to avoid spurious dependencies of the cache on these arguments. """ - _cse_cache_dict: dict[tuple[ExpressionT, P.args, P.kwargs], ResultT] + _cse_cache_dict: dict[tuple[Expression, P.args, P.kwargs], ResultT] def map_common_subexpression(self, expr: p.CommonSubexpression, @@ -1512,7 +1516,8 @@ def map_common_subexpression(self, except AttributeError: ccd = self._cse_cache_dict = {} - key: tuple[ExpressionT, P.args, P.kwargs] = (expr, args, immutabledict(kwargs)) + key: tuple[Expression, P.args, P.kwargs] = ( + expr, args, immutabledict(kwargs)) try: return ccd[key] except KeyError: diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py index 6a7feb55..72315f98 100644 --- a/pymbolic/mapper/coefficient.py +++ b/pymbolic/mapper/coefficient.py @@ -28,10 +28,10 @@ import pymbolic.primitives as p from pymbolic.mapper import Mapper -from pymbolic.typing import ArithmeticExpressionT +from pymbolic.typing import ArithmeticExpression -CoeffsT: TypeAlias = Mapping[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT] +CoeffsT: TypeAlias = Mapping[p.AlgebraicLeaf | Literal[1], ArithmeticExpression] class CoefficientCollector(Mapper[CoeffsT, []]): @@ -41,7 +41,7 @@ def __init__(self, target_names: Collection[str] | None = None) -> None: def map_sum(self, expr: p.Sum) -> CoeffsT: stride_dicts = [self.rec(ch) for ch in expr.children] - result: dict[p.AlgebraicLeaf | Literal[1], ArithmeticExpressionT] = {} + result: dict[p.AlgebraicLeaf | Literal[1], ArithmeticExpression] = {} for stride_dict in stride_dicts: for var, stride in stride_dict.items(): if var in result: @@ -64,11 +64,11 @@ def map_product(self, expr: p.Product) -> CoeffsT: "nonlinear expression") idx_of_child_with_vars = i - other_coeffs: ArithmeticExpressionT = 1 + other_coeffs: ArithmeticExpression = 1 for i, child_coeffs in enumerate(children_coeffs): if i != idx_of_child_with_vars: assert len(child_coeffs) == 1 - other_coeffs *= cast(ArithmeticExpressionT, child_coeffs[1]) + other_coeffs *= cast(ArithmeticExpression, child_coeffs[1]) if idx_of_child_with_vars is None: return {1: other_coeffs} diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py index 5b211948..c6c89eff 100644 --- a/pymbolic/mapper/collector.py +++ b/pymbolic/mapper/collector.py @@ -33,7 +33,7 @@ import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper from pymbolic.mapper.dependency import DependenciesT -from pymbolic.typing import ArithmeticExpressionT, ExpressionT +from pymbolic.typing import ArithmeticExpression, Expression class TermCollector(IdentityMapper[[]]): @@ -49,13 +49,13 @@ def __init__(self, parameters: Set[p.AlgebraicLeaf] | None = None): parameters = set() self.parameters = parameters - def get_dependencies(self, expr: ExpressionT) -> DependenciesT: + def get_dependencies(self, expr: Expression) -> DependenciesT: from pymbolic.mapper.dependency import DependencyMapper return DependencyMapper()(expr) - def split_term(self, mul_term: ExpressionT) -> tuple[ - Set[tuple[ArithmeticExpressionT, ArithmeticExpressionT]], - ArithmeticExpressionT + def split_term(self, mul_term: Expression) -> tuple[ + Set[tuple[ArithmeticExpression, ArithmeticExpression]], + ArithmeticExpression ]: """Returns a pair consisting of: - a frozenset of (base, exponent) pairs @@ -67,21 +67,21 @@ def split_term(self, mul_term: ExpressionT) -> tuple[ """ from pymbolic.primitives import AlgebraicLeaf, Power, Product - def base(term: ExpressionT) -> ArithmeticExpressionT: + def base(term: Expression) -> ArithmeticExpression: if isinstance(term, Power): return term.base else: assert p.is_arithmetic_expression(term) return term - def exponent(term: ExpressionT) -> ArithmeticExpressionT: + def exponent(term: Expression) -> ArithmeticExpression: if isinstance(term, Power): return term.exponent else: return 1 if isinstance(mul_term, Product): - terms: Sequence[ExpressionT] = mul_term.children + terms: Sequence[Expression] = mul_term.children elif isinstance(mul_term, Power | AlgebraicLeaf): terms = [mul_term] elif not bool(self.get_dependencies(mul_term)): @@ -89,7 +89,7 @@ def exponent(term: ExpressionT) -> ArithmeticExpressionT: else: raise RuntimeError("split_term expects a multiplicative term") - base2exp: dict[ArithmeticExpressionT, ArithmeticExpressionT] = {} + base2exp: dict[ArithmeticExpression, ArithmeticExpression] = {} for term in terms: mybase = base(term) myexp = exponent(term) @@ -110,13 +110,13 @@ def exponent(term: ExpressionT) -> ArithmeticExpressionT: base_exp_set = frozenset( (base, exp) for base, exp in cleaned_base2exp.items()) - return base_exp_set, cast(ArithmeticExpressionT, + return base_exp_set, cast(ArithmeticExpression, self.rec(pymbolic.flattened_product(coefficients))) - def map_sum(self, expr: p.Sum) -> ExpressionT: + def map_sum(self, expr: p.Sum) -> Expression: term2coeff: dict[ - Set[tuple[ArithmeticExpressionT, ArithmeticExpressionT]], - ArithmeticExpressionT] = {} + Set[tuple[ArithmeticExpression, ArithmeticExpression]], + ArithmeticExpression] = {} for child in expr.children: term, coeff = self.split_term(child) term2coeff[term] = term2coeff.get(term, 0) + coeff diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py index e68971ab..70f3b1b7 100644 --- a/pymbolic/mapper/constant_folder.py +++ b/pymbolic/mapper/constant_folder.py @@ -35,10 +35,10 @@ Mapper, ) from pymbolic.primitives import Product, Sum, is_arithmetic_expression -from pymbolic.typing import ArithmeticExpressionT, ExpressionT +from pymbolic.typing import ArithmeticExpression, Expression -class ConstantFoldingMapperBase(Mapper[ExpressionT, []]): +class ConstantFoldingMapperBase(Mapper[Expression, []]): def is_constant(self, expr): from pymbolic.mapper.dependency import DependencyMapper return not bool(DependencyMapper()(expr)) @@ -53,16 +53,16 @@ def evaluate(self, expr): def fold(self, expr: Sum | Product, op: Callable[ - [ArithmeticExpressionT, ArithmeticExpressionT], - ArithmeticExpressionT], + [ArithmeticExpression, ArithmeticExpression], + ArithmeticExpression], constructor: Callable[ - [tuple[ArithmeticExpressionT, ...]], - ArithmeticExpressionT], - ) -> ExpressionT: + [tuple[ArithmeticExpression, ...]], + ArithmeticExpression], + ) -> Expression: klass = type(expr) - constants: list[ArithmeticExpressionT] = [] - nonconstants: list[ArithmeticExpressionT] = [] + constants: list[ArithmeticExpression] = [] + nonconstants: list[ArithmeticExpression] = [] queue = list(expr.children) while queue: @@ -90,7 +90,7 @@ def fold(self, else: return constructor(tuple(nonconstants)) - def map_sum(self, expr: Sum) -> ExpressionT: + def map_sum(self, expr: Sum) -> Expression: import operator from pymbolic.primitives import flattened_sum @@ -108,7 +108,7 @@ def map_product(self, expr): class ConstantFoldingMapper( - CSECachingMapperMixin[ExpressionT, []], + CSECachingMapperMixin[Expression, []], ConstantFoldingMapperBase, IdentityMapper[[]]): @@ -117,7 +117,7 @@ class ConstantFoldingMapper( class CommutativeConstantFoldingMapper( - CSECachingMapperMixin[ExpressionT, []], + CSECachingMapperMixin[Expression, []], CommutativeConstantFoldingMapperBase, IdentityMapper[[]]): diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py index 85d13504..b8e33d5c 100644 --- a/pymbolic/mapper/distributor.py +++ b/pymbolic/mapper/distributor.py @@ -34,7 +34,7 @@ from pymbolic.mapper import IdentityMapper from pymbolic.mapper.collector import TermCollector from pymbolic.mapper.constant_folder import CommutativeConstantFoldingMapper -from pymbolic.typing import ArithmeticExpressionT, ExpressionT +from pymbolic.typing import ArithmeticExpression, Expression class DistributeMapper(IdentityMapper[[]]): @@ -69,7 +69,7 @@ def map_sum(self, expr): else: return res - def map_product(self, expr: p.Product) -> ExpressionT: + def map_product(self, expr: p.Product) -> Expression: def dist(prod): if not isinstance(prod, p.Product): return prod @@ -112,13 +112,13 @@ def map_quotient(self, expr): self.rec(expr.numerator) ]) - def map_power(self, expr: p.Power) -> ExpressionT: + def map_power(self, expr: p.Power) -> Expression: from pymbolic.primitives import Sum newbase = self.rec(expr.base) if isinstance(newbase, p.Product): return self.rec(pymbolic.flattened_product([ - cast(ArithmeticExpressionT, child)**expr.exponent + cast(ArithmeticExpression, child)**expr.exponent for child in newbase.children ])) @@ -133,7 +133,9 @@ def map_power(self, expr: p.Power) -> ExpressionT: return IdentityMapper.map_power(self, expr) -def distribute(expr: ExpressionT, parameters=None, commutative=True) -> ExpressionT: +def distribute( + expr: Expression, parameters=None, commutative=True + ) -> Expression: if parameters is None: parameters = frozenset() if commutative: diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index bdff4c2d..d38a395c 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -40,7 +40,7 @@ import pymbolic.primitives as p from pymbolic.mapper import CachedMapper, CSECachingMapperMixin, Mapper, ResultT -from pymbolic.typing import ExpressionT +from pymbolic.typing import Expression if TYPE_CHECKING: @@ -155,7 +155,7 @@ def map_logical_or(self, expr: p.LogicalOr) -> bool: # type: ignore[override] def map_logical_and(self, expr: p.LogicalAnd) -> bool: # type: ignore[override] return all(self.rec(ch) for ch in expr.children) - def map_list(self, expr: list[ExpressionT]) -> ResultT: + def map_list(self, expr: list[Expression]) -> ResultT: return [self.rec(child) for child in expr] # type: ignore[return-value] def map_numpy_array(self, expr: np.ndarray) -> ResultT: @@ -188,7 +188,7 @@ def map_min(self, expr: p.Min) -> ResultT: def map_max(self, expr: p.Max) -> ResultT: return max(self.rec(child) for child in expr.children) # type: ignore[type-var] - def map_tuple(self, expr: tuple[ExpressionT, ...]) -> ResultT: + def map_tuple(self, expr: tuple[Expression, ...]) -> ResultT: return tuple([self.rec(child) for child in expr]) # type: ignore[return-value] def map_nan(self, expr: p.NaN) -> ResultT: @@ -222,7 +222,7 @@ def map_rational(self, expr) -> float: def evaluate( - expression: ExpressionT, + expression: Expression, context: Mapping[str, ResultT] | None = None, mapper_cls: type[EvaluationMapper[ResultT]] = CachedEvaluationMapper, ) -> ResultT: @@ -236,7 +236,7 @@ def evaluate( def evaluate_kw( - expression: ExpressionT, + expression: Expression, mapper_cls: type[EvaluationMapper[ResultT]] = CachedEvaluationMapper, **context: ResultT, ) -> ResultT: diff --git a/pymbolic/mapper/flattener.py b/pymbolic/mapper/flattener.py index 654c0353..e5972c0b 100644 --- a/pymbolic/mapper/flattener.py +++ b/pymbolic/mapper/flattener.py @@ -35,7 +35,11 @@ import pymbolic.primitives as p from pymbolic.mapper import IdentityMapper -from pymbolic.typing import ArithmeticExpressionT, ArithmeticOrExpressionT, ExpressionT +from pymbolic.typing import ( + ArithmeticExpression, + ArithmeticOrExpressionT, + Expression, +) class FlattenMapper(IdentityMapper[[]]): @@ -52,7 +56,7 @@ class FlattenMapper(IdentityMapper[[]]): .. automethod:: is_expr_integer_valued """ - def is_expr_integer_valued(self, expr: ExpressionT) -> bool: + def is_expr_integer_valued(self, expr: Expression) -> bool: """A user-supplied method to indicate whether a given *expr* is integer- valued. This enables additional simplifications that are not valid in general. The default implementation simply returns *False*. @@ -61,19 +65,19 @@ def is_expr_integer_valued(self, expr: ExpressionT) -> bool: """ return False - def map_sum(self, expr: p.Sum) -> ExpressionT: + def map_sum(self, expr: p.Sum) -> Expression: from pymbolic.primitives import flattened_sum return flattened_sum([ - cast(ArithmeticExpressionT, self.rec(ch)) + cast(ArithmeticExpression, self.rec(ch)) for ch in expr.children]) - def map_product(self, expr: p.Product) -> ExpressionT: + def map_product(self, expr: p.Product) -> Expression: from pymbolic.primitives import flattened_product return flattened_product([ - cast(ArithmeticExpressionT, self.rec(ch)) + cast(ArithmeticExpression, self.rec(ch)) for ch in expr.children]) - def map_quotient(self, expr: p.Quotient) -> ExpressionT: + def map_quotient(self, expr: p.Quotient) -> Expression: r_num = self.rec_arith(expr.numerator) r_den = self.rec_arith(expr.denominator) if p.is_zero(r_num): @@ -83,7 +87,7 @@ def map_quotient(self, expr: p.Quotient) -> ExpressionT: return expr.__class__(r_num, r_den) - def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT: + def map_floor_div(self, expr: p.FloorDiv) -> Expression: r_num = self.rec_arith(expr.numerator) r_den = self.rec_arith(expr.denominator) if p.is_zero(r_num): @@ -95,7 +99,7 @@ def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT: return expr.__class__(r_num, r_den) - def map_remainder(self, expr: p.Remainder) -> ExpressionT: + def map_remainder(self, expr: p.Remainder) -> Expression: r_num = self.rec_arith(expr.numerator) r_den = self.rec_arith(expr.denominator) assert p.is_arithmetic_expression(r_den) @@ -108,7 +112,7 @@ def map_remainder(self, expr: p.Remainder) -> ExpressionT: return expr.__class__(r_num, r_den) - def map_power(self, expr: p.Power) -> ExpressionT: + def map_power(self, expr: p.Power) -> Expression: r_base = self.rec_arith(expr.base) r_exp = self.rec_arith(expr.exponent) diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index a3ccd3fc..7e14aabf 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -30,7 +30,7 @@ import pymbolic.primitives as p from pymbolic.mapper import CachedMapper, Mapper, P -from pymbolic.typing import ExpressionT +from pymbolic.typing import Expression if TYPE_CHECKING: @@ -95,11 +95,11 @@ class StringifyMapper(Mapper[str, Concatenate[int, P]]): """A mapper to turn an expression tree into a string. - :class:`pymbolic.Expression.__str__` is often implemented using + :class:`pymbolic.ExpressionNode.__str__` is often implemented using this mapper. - When it encounters an unsupported :class:`pymbolic.Expression` - subclass, it calls its :meth:`pymbolic.Expression.make_stringifier` + When it encounters an unsupported :class:`pymbolic.ExpressionNode` + subclass, it calls its :meth:`pymbolic.ExpressionNode.make_stringifier` method to get a :class:`StringifyMapper` that potentially does. """ @@ -108,7 +108,7 @@ class StringifyMapper(Mapper[str, Concatenate[int, P]]): def format(self, s: str, *args: object) -> str: return s % args - def join(self, joiner: str, seq: Sequence[ExpressionT]) -> str: + def join(self, joiner: str, seq: Sequence[Expression]) -> str: return self.format(joiner.join("%s" for _ in seq), *seq) # {{{ deprecated junk @@ -136,7 +136,7 @@ def rec_with_force_parens_around(self, expr, *args, **kwargs): def join_rec( self, joiner: str, - seq: Sequence[ExpressionT], + seq: Sequence[Expression], prec: int, *args, **kwargs, # force_with_parens_around may hide in here @@ -167,7 +167,7 @@ def join_rec( def rec_with_parens_around_types( self, - expr: ExpressionT, + expr: Expression, enclosing_prec: int, parens_around: tuple[type, ...], *args: P.args, @@ -183,7 +183,7 @@ def rec_with_parens_around_types( def join_rec_with_parens_around_types( self, joiner: str, - seq: Sequence[ExpressionT], + seq: Sequence[Expression], prec: int, parens_around_types: tuple[type, ...], *args: P.args, @@ -565,7 +565,7 @@ def map_logical_and( def map_list( self, - expr: list[ExpressionT], + expr: list[Expression], enclosing_prec: int, *args: P.args, **kwargs: P.kwargs, @@ -578,7 +578,7 @@ def map_list( def map_tuple( self, - expr: tuple[ExpressionT, ...], + expr: tuple[Expression, ...], enclosing_prec: int, *args: P.args, **kwargs: P.kwargs, @@ -773,7 +773,7 @@ class CSESplittingStringifyMapperMixin(Mapper[str, Concatenate[int, P]]): of the use of this mix-in. """ - cse_to_name: dict[ExpressionT, str] + cse_to_name: dict[Expression, str] cse_names: set[str] cse_name_list: list[tuple[str, str]] diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py index b584d0d7..7948fca7 100644 --- a/pymbolic/mapper/substitutor.py +++ b/pymbolic/mapper/substitutor.py @@ -4,7 +4,7 @@ .. autofunction:: make_subst_func .. autofunction:: substitute -.. autoclass:: Callable[[AlgebraicLeaf], ExpressionT | None] +.. autoclass:: Callable[[AlgebraicLeaf], Expression | None] References ---------- @@ -43,7 +43,7 @@ from pymbolic.mapper import CachedIdentityMapper, IdentityMapper from pymbolic.primitives import AlgebraicLeaf -from pymbolic.typing import ExpressionT +from pymbolic.typing import Expression if TYPE_CHECKING: @@ -52,7 +52,7 @@ class SubstitutionMapper(IdentityMapper[[]]): def __init__( - self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None] + self, subst_func: Callable[[AlgebraicLeaf], Expression | None] ) -> None: self.subst_func = subst_func @@ -80,7 +80,7 @@ def map_lookup(self, expr): class CachedSubstitutionMapper(CachedIdentityMapper[[]], SubstitutionMapper): def __init__( - self, subst_func: Callable[[AlgebraicLeaf], ExpressionT | None] + self, subst_func: Callable[[AlgebraicLeaf], Expression | None] ) -> None: # FIXME Mypy says: # error: Argument 1 to "__init__" of "CachedMapper" has incompatible type @@ -93,11 +93,11 @@ def __init__( def make_subst_func( # "Any" here avoids the whole Mapping variance disaster # e.g. https://github.com/python/typing/issues/445 - variable_assignments: SupportsGetItem[Any, ExpressionT], -) -> Callable[[AlgebraicLeaf], ExpressionT | None]: + variable_assignments: SupportsGetItem[Any, Expression], +) -> Callable[[AlgebraicLeaf], Expression | None]: import pymbolic.primitives as primitives - def subst_func(var: AlgebraicLeaf) -> ExpressionT | None: + def subst_func(var: AlgebraicLeaf) -> Expression | None: try: return variable_assignments[var] except KeyError: @@ -113,10 +113,11 @@ def subst_func(var: AlgebraicLeaf) -> ExpressionT | None: def substitute( - expression: ExpressionT, - variable_assignments: SupportsItems[AlgebraicLeaf | str, ExpressionT] | None = None, + expression: Expression, + variable_assignments: SupportsItems[AlgebraicLeaf | str, Expression] | None + = None, mapper_cls=CachedSubstitutionMapper, - **kwargs: ExpressionT, + **kwargs: Expression, ): """ :arg mapper_cls: A :class:`type` of the substitution mapper @@ -125,7 +126,7 @@ def substitute( if variable_assignments is None: # "Any" here avoids pointless grief about variance # e.g. https://github.com/python/typing/issues/445 - v_ass_copied: dict[Any, ExpressionT] = {} + v_ass_copied: dict[Any, Expression] = {} else: v_ass_copied = dict(variable_assignments.items()) diff --git a/pymbolic/parser.py b/pymbolic/parser.py index fd4248f2..8a12e471 100644 --- a/pymbolic/parser.py +++ b/pymbolic/parser.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pymbolic.typing import ExpressionT +from pymbolic.typing import Expression __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" @@ -561,7 +561,7 @@ def parse_arglist(self, pstate): comma_allowed = True - def __call__(self, expr_str: str, min_precedence: int = 0) -> ExpressionT: + def __call__(self, expr_str: str, min_precedence: int = 0) -> Expression: lex_result = [(tag, s, idx, match_obj) for (tag, s, idx, match_obj) in pytools.lex.lex( self.lex_table, expr_str, diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index cd9beaf1..57623a59 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -26,6 +26,7 @@ import re from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass, fields +from functools import partial from sys import intern from typing import ( TYPE_CHECKING, @@ -42,8 +43,10 @@ from immutabledict import immutabledict from typing_extensions import TypeIs, dataclass_transform +from pytools import module_getattr_for_deprecations + from . import traits -from .typing import ArithmeticExpressionT, ExpressionT, NumberT, ScalarT +from .typing import ArithmeticExpression, Expression as _Expression, Number, Scalar if TYPE_CHECKING: @@ -56,7 +59,7 @@ .. currentmodule:: pymbolic -.. autoclass:: Expression +.. autoclass:: ExpressionNode .. autofunction:: expr_dataclass @@ -296,9 +299,9 @@ An instance of a :func:`~dataclasses.dataclass`. -.. class:: ArithmeticExpressionT +.. class:: ArithmeticExpression - See :class:`pymbolic.ArithmeticExpressionT` + See :class:`pymbolic.ArithmeticExpression` .. class:: _T @@ -320,9 +323,9 @@ See :class:`pymbolic.Variable`. -.. class:: ExpressionT +.. class:: Expression - See :class:`pymbolic.ExpressionT`. + See :data:`pymbolic.typing.Expression`. .. currentmodule:: pymbolic @@ -370,11 +373,11 @@ def __get__(self, owner_self, owner_cls): class _AttributeLookupCreator: - """Helper used by :attr:`pymbolic.Expression.a` to create lookups. + """Helper used by :attr:`pymbolic.ExpressionNode.a` to create lookups. .. automethod:: __getattr__ """ - def __init__(self, aggregate: ExpressionT) -> None: + def __init__(self, aggregate: _Expression) -> None: self.aggregate = aggregate def __getattr__(self, name: str) -> Lookup: @@ -383,10 +386,10 @@ def __getattr__(self, name: str) -> Lookup: @dataclass(frozen=True) class EmptyOK: - child: ExpressionT + child: _Expression -class Expression: +class ExpressionNode: """Superclass for parts of a mathematical expression. Overrides operators to implicitly construct :class:`~pymbolic.primitives.Sum`, :class:`~pymbolic.primitives.Product` and other expressions. @@ -606,10 +609,10 @@ def __rand__(self, other: object) -> BitwiseAnd: # {{{ misc - def __neg__(self) -> ArithmeticExpressionT: + def __neg__(self) -> ArithmeticExpression: return -1*self - def __pos__(self) -> ArithmeticExpressionT: + def __pos__(self) -> ArithmeticExpression: return self def __call__(self, *args, **kwargs) -> Call | CallWithKwargs: @@ -623,7 +626,7 @@ def __call__(self, *args, **kwargs) -> Call | CallWithKwargs: # Subscript has an attribute 'index' which can't coexist with this. # Thus we're hiding this from mypy until it goes away. - def index(self, subscript: Expression) -> Expression: + def index(self, subscript: ExpressionNode) -> ExpressionNode: """Return an expression representing ``self[subscript]``. .. versionadded:: 2014.3 @@ -634,7 +637,7 @@ def index(self, subscript: Expression) -> Expression: return self[subscript] - def __getitem__(self, subscript: ExpressionT | EmptyOK) -> Expression: + def __getitem__(self, subscript: _Expression | EmptyOK) -> ExpressionNode: """Return an expression representing ``self[subscript]``. """ if isinstance(subscript, EmptyOK): @@ -699,7 +702,7 @@ def strify_child(child, limit): ", ".join(strify_child(i, limit-1) for i in child), "," if len(child) == 1 else "") - elif isinstance(child, Expression): + elif isinstance(child, ExpressionNode): return child._safe_repr(limit=limit-1) else: return repr(child) @@ -725,7 +728,8 @@ def __repr__(self) -> str: # This custom warning deduplication mechanism became necessary because the # sheer amount of warnings ended up leading to out-of-memory situations # with pytest which bufered all the warnings. - _deprecation_warnings_issued: ClassVar[set[tuple[type[Expression], str]]] = set() + _deprecation_warnings_issued: ClassVar[set[tuple[type[ExpressionNode], str]]] \ + = set() def __eq__(self, other) -> bool: """Provides equality testing with quick positive and negative paths @@ -879,7 +883,7 @@ def __gt__(self, other) -> NoReturn: # }}} - def __abs__(self) -> Expression: + def __abs__(self) -> ExpressionNode: return Call(Variable("abs"), (self,)) def __iter__(self): @@ -1084,17 +1088,17 @@ def expr_dataclass( hash: bool = True, ) -> Callable[[type[_T]], type[_T]]: r"""A class decorator that makes the class a :func:`~dataclasses.dataclass` - while also adding functionality needed for :class:`Expression` nodes. + while also adding functionality needed for :class:`ExpressionNode`. Specifically, it adds cached hashing, equality comparisons with ``self is other`` shortcuts as well as some methods/attributes for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``). - It also adds a :attr:`Expression.mapper_method` based on the class name - if not already present. If :attr:`~Expression.mapper_method` is inherited, + It also adds a :attr:`ExpressionNode.mapper_method` based on the class name + if not already present. If :attr:`~ExpressionNode.mapper_method` is inherited, it will be viewed as unset and replaced. Note that the class to which this decorator is applied need not be - a subclass of :class:`~pymbolic.Expression`. + a subclass of :class:`ExpressionNode`. .. versionadded:: 2024.1 """ @@ -1120,7 +1124,7 @@ def map_cls(cls: type[_T]) -> type[_T]: @expr_dataclass() -class AlgebraicLeaf(Expression): +class AlgebraicLeaf(ExpressionNode): """An expression that serves as a leaf for arithmetic evaluation. This may end up having child nodes still, but they're not reached by ways of arithmetic.""" @@ -1179,13 +1183,12 @@ class Call(AlgebraicLeaf): .. autoattribute:: function .. autoattribute:: parameters """ - function: ExpressionT - """A :class:`Expression` that evaluates to a function.""" + function: _Expression + """Evaluates to the function to be called.""" - parameters: tuple[ExpressionT, ...] + parameters: tuple[_Expression, ...] """ - A :class:`tuple` of positional parameters, each element - of which is a :class:`Expression` or a constant. + A :class:`tuple` of positional parameters. """ @@ -1198,17 +1201,15 @@ class CallWithKwargs(AlgebraicLeaf): .. autoattribute:: kw_parameters """ - function: ExpressionT - """An :class:`Expression` that evaluates to a function.""" + function: _Expression + """Evaluates to the function to be called.""" - parameters: tuple[ExpressionT, ...] - """A :class:`tuple` of positional parameters, each element - of which is a :class:`Expression` or a constant. + parameters: tuple[_Expression, ...] + """A :class:`tuple` of positional parameters. """ - kw_parameters: Mapping[str, ExpressionT] - """A dictionary mapping names to arguments, each - of which is a :class:`Expression` or a constant. + kw_parameters: Mapping[str, _Expression] + """A dictionary mapping names to arguments. """ def __post_init__(self): @@ -1228,11 +1229,11 @@ def __post_init__(self): class Subscript(AlgebraicLeaf): """An array subscript.""" - aggregate: ExpressionT - index: ExpressionT + aggregate: _Expression + index: _Expression @property - def index_tuple(self) -> tuple[ExpressionT, ...]: + def index_tuple(self) -> tuple[_Expression, ...]: """ Return :attr:`index` wrapped in a single-element tuple, if it is not already a tuple. @@ -1248,7 +1249,7 @@ def index_tuple(self) -> tuple[ExpressionT, ...]: class Lookup(AlgebraicLeaf): """Access to an attribute of an *aggregate*, such as an attribute of a class.""" - aggregate: ExpressionT + aggregate: _Expression name: str # }}} @@ -1257,7 +1258,7 @@ class Lookup(AlgebraicLeaf): # {{{ arithmetic primitives @expr_dataclass() -class Sum(Expression): +class Sum(ExpressionNode): """ .. autoattribute:: children @@ -1267,7 +1268,7 @@ class Sum(Expression): .. automethod:: __bool__ """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] def __add__(self, other): if not is_valid_operand(other): @@ -1310,7 +1311,7 @@ def __bool__(self): @expr_dataclass() -class Product(Expression): +class Product(ExpressionNode): """ .. autoattribute:: children @@ -1319,7 +1320,7 @@ class Product(Expression): .. automethod:: __bool__ """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] def __mul__(self, other): if not is_valid_operand(other): @@ -1353,25 +1354,25 @@ def __bool__(self): @expr_dataclass() -class Min(Expression): +class Min(ExpressionNode): """ .. autoattribute:: children """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] @expr_dataclass() -class Max(Expression): +class Max(ExpressionNode): """ .. autoattribute:: children """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] @expr_dataclass() -class QuotientBase(Expression): - numerator: ArithmeticExpressionT - denominator: ArithmeticExpressionT +class QuotientBase(ExpressionNode): + numerator: ArithmeticExpression + denominator: ArithmeticExpression @property def num(self): @@ -1389,7 +1390,7 @@ def __bool__(self): @expr_dataclass() class Quotient(QuotientBase): - """Bases: :class:`~pymbolic.Expression` + """Bases: :class:`~pymbolic.ExpressionNode` .. autoattribute:: numerator .. autoattribute:: denominator @@ -1398,7 +1399,7 @@ class Quotient(QuotientBase): @expr_dataclass() class FloorDiv(QuotientBase): - """Bases: :class:`~pymbolic.Expression` + """Bases: :class:`~pymbolic.ExpressionNode` .. autoattribute:: numerator .. autoattribute:: denominator @@ -1407,7 +1408,7 @@ class FloorDiv(QuotientBase): @expr_dataclass() class Remainder(QuotientBase): - """Bases: :class:`~pymbolic.Expression` + """Bases: :class:`~pymbolic.ExpressionNode` .. autoattribute:: numerator .. autoattribute:: denominator @@ -1415,14 +1416,14 @@ class Remainder(QuotientBase): @expr_dataclass() -class Power(Expression): +class Power(ExpressionNode): """ .. autoattribute:: base .. autoattribute:: exponent """ - base: ArithmeticExpressionT - exponent: ArithmeticExpressionT + base: ArithmeticExpression + exponent: ArithmeticExpression # }}} @@ -1430,14 +1431,14 @@ class Power(Expression): # {{{ shift operators @expr_dataclass() -class _ShiftOperator(Expression): - shiftee: ExpressionT - shift: ExpressionT +class _ShiftOperator(ExpressionNode): + shiftee: _Expression + shift: _Expression @expr_dataclass() class LeftShift(_ShiftOperator): - """Bases: :class:`~pymbolic.Expression`. + """Bases: :class:`~pymbolic.ExpressionNode`. .. autoattribute:: shiftee .. autoattribute:: shift @@ -1446,7 +1447,7 @@ class LeftShift(_ShiftOperator): @expr_dataclass() class RightShift(_ShiftOperator): - """Bases: :class:`~pymbolic.Expression`. + """Bases: :class:`~pymbolic.ExpressionNode`. .. autoattribute:: shiftee .. autoattribute:: shift @@ -1458,37 +1459,37 @@ class RightShift(_ShiftOperator): # {{{ bitwise operators @expr_dataclass() -class BitwiseNot(Expression): +class BitwiseNot(ExpressionNode): """ .. autoattribute:: child """ - child: ExpressionT + child: _Expression @expr_dataclass() -class BitwiseOr(Expression): +class BitwiseOr(ExpressionNode): """ .. autoattribute:: children """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] @expr_dataclass() -class BitwiseXor(Expression): +class BitwiseXor(ExpressionNode): """ .. autoattribute:: children """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] @expr_dataclass() -class BitwiseAnd(Expression): +class BitwiseAnd(ExpressionNode): """ .. autoattribute:: children """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] # }}} @@ -1496,7 +1497,7 @@ class BitwiseAnd(Expression): # {{{ comparisons, logic, conditionals @expr_dataclass() -class Comparison(Expression): +class Comparison(ExpressionNode): """ .. autoattribute:: left .. autoattribute:: operator @@ -1505,18 +1506,19 @@ class Comparison(Expression): .. note:: Unlike other expressions, comparisons are not implicitly constructed by - comparing :class:`Expression` objects. See :meth:`pymbolic.Expression.eq`. + comparing :class:`ExpressionNode` objects. + See :meth:`pymbolic.ExpressionNode.eq` and related. .. autoattribute:: operator_to_name .. autoattribute:: name_to_operator """ - left: ExpressionT + left: _Expression operator: str """One of ``[">", ">=", "==", "!=", "<", "<="]``.""" - right: ExpressionT + right: _Expression operator_to_name: ClassVar[dict[str, str]] = { "==": "eq", @@ -1547,42 +1549,42 @@ def __post_init__(self): @expr_dataclass() -class LogicalNot(Expression): +class LogicalNot(ExpressionNode): """ .. autoattribute:: child """ - child: ExpressionT + child: _Expression @expr_dataclass() -class LogicalOr(Expression): +class LogicalOr(ExpressionNode): """ .. autoattribute:: children """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] @expr_dataclass() -class LogicalAnd(Expression): +class LogicalAnd(ExpressionNode): """ .. autoattribute:: children """ - children: tuple[ExpressionT, ...] + children: tuple[_Expression, ...] @expr_dataclass() -class If(Expression): +class If(ExpressionNode): """ .. autoattribute:: condition .. autoattribute:: then .. autoattribute:: else_ """ - condition: ExpressionT - then: ExpressionT - else_: ExpressionT + condition: _Expression + then: _Expression + else_: _Expression # }}} @@ -1617,7 +1619,7 @@ class cse_scope: # noqa @expr_dataclass() -class CommonSubexpression(Expression): +class CommonSubexpression(ExpressionNode): """A helper for code generation and caching. Denotes a subexpression that should only be evaluated once. If, in code generation, it is assigned to a variable, a name starting with :attr:`prefix` should be used. @@ -1630,7 +1632,7 @@ class CommonSubexpression(Expression): See :class:`pymbolic.mapper.c_code.CCodeMapper` for an example. """ - child: ExpressionT + child: _Expression prefix: str | None = None scope: str = cse_scope.EVALUATION """ @@ -1657,7 +1659,7 @@ def get_extra_properties(self): @expr_dataclass() -class Substitution(Expression): +class Substitution(ExpressionNode): """Work-alike of :class:`~sympy.core.function.Subs`. .. autoattribute:: child @@ -1665,31 +1667,32 @@ class Substitution(Expression): .. autoattribute:: values """ - child: ExpressionT + child: _Expression variables: tuple[str, ...] - values: tuple[ExpressionT, ...] + values: tuple[_Expression, ...] @expr_dataclass() -class Derivative(Expression): +class Derivative(ExpressionNode): """Work-alike of sympy's :class:`~sympy.core.function.Derivative`. .. autoattribute:: child .. autoattribute:: variables """ - child: ExpressionT + child: _Expression variables: tuple[str, ...] SliceChildrenT: TypeAlias = (tuple[()] - | tuple[ExpressionT | None] - | tuple[ExpressionT | None, ExpressionT | None] - | tuple[ExpressionT | None, ExpressionT | None, ExpressionT | None]) + | tuple[_Expression | None] + | tuple[_Expression | None, _Expression | None] + | tuple[_Expression | None, _Expression | None, _Expression + | None]) @expr_dataclass() -class Slice(Expression): +class Slice(ExpressionNode): """A slice expression as in a[1:7]. .. autoattribute:: children @@ -1772,7 +1775,7 @@ def subscript(expression, index): return Subscript(expression, index) -def flattened_sum(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpressionT: +def flattened_sum(terms: Iterable[ArithmeticExpression]) -> ArithmeticExpression: r"""Recursively flattens all the top level :class:`Sum`\ s in *terms*. :arg terms: an :class:`~collections.abc.Iterable` of expressions. @@ -1789,7 +1792,7 @@ def flattened_sum(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpressio continue if isinstance(item, Sum): - ch = cast(tuple[ArithmeticExpressionT], item.children) + ch = cast(tuple[ArithmeticExpression], item.children) queue.extend(ch) else: done.append(item) @@ -1809,7 +1812,7 @@ def linear_combination(coefficients, expressions): if coefficient and expression) -def flattened_product(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpressionT: +def flattened_product(terms: Iterable[ArithmeticExpression]) -> ArithmeticExpression: r"""Recursively flattens all the top level :class:`Product`\ s in *terms*. This operation does not change the order of the terms in the products, so @@ -1831,7 +1834,7 @@ def flattened_product(terms: Iterable[ArithmeticExpressionT]) -> ArithmeticExpre continue if isinstance(item, Product): - ch = cast(tuple[ArithmeticExpressionT], item.children) + ch = cast(tuple[ArithmeticExpression], item.children) queue.extend(ch) else: done.append(item) @@ -1873,7 +1876,7 @@ def quotient(numerator, denominator): global VALID_OPERANDS VALID_CONSTANT_CLASSES: tuple[type, ...] = (int, float, complex) _BOOL_CLASSES: tuple[type, ...] = (bool,) -VALID_OPERANDS = (Expression,) +VALID_OPERANDS = (ExpressionNode,) try: import numpy @@ -1883,20 +1886,20 @@ def quotient(numerator, denominator): pass -def is_constant(value: object) -> TypeIs[ScalarT]: +def is_constant(value: object) -> TypeIs[Scalar]: return isinstance(value, VALID_CONSTANT_CLASSES) -def is_number(value: object) -> TypeIs[NumberT]: +def is_number(value: object) -> TypeIs[Number]: return (not isinstance(value, _BOOL_CLASSES) and isinstance(value, VALID_CONSTANT_CLASSES)) -def is_valid_operand(value: object) -> TypeIs[ExpressionT]: +def is_valid_operand(value: object) -> TypeIs[_Expression]: return isinstance(value, VALID_OPERANDS) or is_constant(value) -def is_arithmetic_expression(value: object) -> TypeIs[ArithmeticExpressionT]: +def is_arithmetic_expression(value: object) -> TypeIs[ArithmeticExpression]: return not isinstance(value, _BOOL_CLASSES) and is_valid_operand(value) @@ -1928,9 +1931,9 @@ def is_zero(value: object) -> bool: return not is_nonzero(value) -def wrap_in_cse(expr: ExpressionT, +def wrap_in_cse(expr: _Expression, prefix: str | None = None, - scope: str | None = None) -> ExpressionT: + scope: str | None = None) -> _Expression: warn("'wrap_in_cse' is deprecated and will be removed in 2025. Use " "'make_common_subexpression' with the `wrap_vars=False` flag instead.", DeprecationWarning, stacklevel=2) @@ -1938,11 +1941,11 @@ def wrap_in_cse(expr: ExpressionT, return make_common_subexpression(expr, prefix, scope, wrap_vars=False) -def make_common_subexpression(expr: ExpressionT, +def make_common_subexpression(expr: _Expression, prefix: str | None = None, scope: str | None = None, *, - wrap_vars: bool = True) -> ExpressionT: + wrap_vars: bool = True) -> _Expression: """Wrap *expr* in a :class:`CommonSubexpression` with *prefix*. If *expr* is a :mod:`numpy` object array, each individual entry is instead @@ -2069,4 +2072,9 @@ def variables(s): # }}} +__getattr__ = partial(module_getattr_for_deprecations, __name__, { + "Expression": ("ExpressionNode", ExpressionNode, 2026), + }) + + # vim: foldmethod=marker diff --git a/pymbolic/rational.py b/pymbolic/rational.py index 22465825..7d3c222e 100644 --- a/pymbolic/rational.py +++ b/pymbolic/rational.py @@ -29,7 +29,7 @@ import pymbolic.traits as traits -class Rational(primitives.Expression): +class Rational(primitives.ExpressionNode): def __init__(self, numerator, denominator=1): d_unit = traits.traits(denominator).get_unit(denominator) numerator /= d_unit @@ -74,9 +74,9 @@ def __add__(self, other): gcd = t.gcd(newden, newnum) return primitives.quotient(newnum/gcd, newden/gcd) except traits.NoTraitsError: - return primitives.Expression.__add__(self, other) + return primitives.ExpressionNode.__add__(self, other) except traits.NoCommonTraitsError: - return primitives.Expression.__add__(self, other) + return primitives.ExpressionNode.__add__(self, other) __radd__ = __add__ @@ -106,9 +106,9 @@ def __mul__(self, other): return Rational(new_num, new_denom) except traits.NoTraitsError: - return primitives.Expression.__mul__(self, other) + return primitives.ExpressionNode.__mul__(self, other) except traits.NoCommonTraitsError: - return primitives.Expression.__mul__(self, other) + return primitives.ExpressionNode.__mul__(self, other) __rmul__ = __mul__ diff --git a/pymbolic/typing.py b/pymbolic/typing.py index a16ed6cf..05630cf5 100644 --- a/pymbolic/typing.py +++ b/pymbolic/typing.py @@ -4,29 +4,62 @@ Typing helpers -------------- -.. autoclass:: BoolT -.. autoclass:: NumberT -.. autoclass:: ScalarT -.. autoclass:: ArithmeticExpressionT +.. autoclass:: Bool +.. autoclass:: Number +.. autoclass:: Scalar +.. autoclass:: ArithmeticExpression - A narrower type alias than :class:`ExpressionT` that is returned by + A narrower type alias than :class:`Expression` that is returned by arithmetic operators, to allow continue doing arithmetic with the result of arithmetic. -.. autoclass:: ExpressionT - .. currentmodule:: pymbolic.typing +.. autoclass:: Expression + +.. note:: + + For backward compatibility, ``pymbolic.Expression`` + will alias :class:`pymbolic.primitives.ExpressionNode` for now. Once its deprecation + period is up, it will be removed, and then, in the further future, + ``pymbolic.Expression`` may become this type alias. + .. autoclass:: ArithmeticOrExpressionT - A type variable that can be either :data:`ArithmeticExpressionT` - or :data:`ExpressionT`. + A type variable that can be either :data:`ArithmeticExpression` + or :data:`Expression`. """ from __future__ import annotations + +__copyright__ = "Copyright (C) 2024 University of Illinois Board of Trustees" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from functools import partial from typing import TYPE_CHECKING, TypeAlias, TypeVar, Union +from pytools import module_getattr_for_deprecations + # FIXME: This is a lie. Many more constant types (e.g. numpy and such) # are in practical use and completely fine. We cannot really add in numpy @@ -45,7 +78,7 @@ # https://github.com/python/typeshed/blob/119cd09655dcb4ed7fb2021654ba809b8d88846f/stdlib/numbers.pyi if TYPE_CHECKING: - from pymbolic.primitives import Expression + from pymbolic.primitives import ExpressionNode # Experience with depending packages showed that including Decimal and Fraction # from the stdlib was more trouble than it's worth because those types don't cleanly @@ -59,39 +92,48 @@ if TYPE_CHECKING: # Yes, type-checking pymbolic will require numpy. That's OK. import numpy as np - BoolT = bool | np.bool_ - IntegerT: TypeAlias = int | np.integer - InexactNumberT: TypeAlias = _StdlibInexactNumberT | np.inexact + Bool = bool | np.bool_ + Integer: TypeAlias = int | np.integer + InexactNumber: TypeAlias = _StdlibInexactNumberT | np.inexact else: try: import numpy as np except ImportError: - BoolT = bool - IntegerT: TypeAlias = int - InexactNumberT: TypeAlias = _StdlibInexactNumberT + Bool = bool + Integer: TypeAlias = int + InexactNumber: TypeAlias = _StdlibInexactNumberT else: - BoolT = bool | np.bool_ - IntegerT: TypeAlias = int | np.integer - InexactNumberT: TypeAlias = _StdlibInexactNumberT | np.inexact + Bool = bool | np.bool_ + Integer: TypeAlias = int | np.integer + InexactNumber: TypeAlias = _StdlibInexactNumberT | np.inexact -NumberT: TypeAlias = IntegerT | InexactNumberT -ScalarT: TypeAlias = NumberT | BoolT +Number: TypeAlias = Integer | InexactNumber +Scalar: TypeAlias = Number | Bool -_ScalarOrExpression = Union[ScalarT, "Expression"] -ArithmeticExpressionT: TypeAlias = Union[NumberT, "Expression"] +_ScalarOrExpression = Union[Scalar, "ExpressionNode"] +ArithmeticExpression: TypeAlias = Union[Number, "ExpressionNode"] -ExpressionT: TypeAlias = _ScalarOrExpression | tuple["ExpressionT", ...] +Expression: TypeAlias = _ScalarOrExpression | tuple["Expression", ...] ArithmeticOrExpressionT = TypeVar( "ArithmeticOrExpressionT", - ArithmeticExpressionT, - ExpressionT) + ArithmeticExpression, + Expression) T = TypeVar("T") +__getattr__ = partial(module_getattr_for_deprecations, __name__, { + "ArithmeticExpressionT": ("ArithmeticExpression", ArithmeticExpression, 2026), + "ExpressionT": ("Expression", Expression, 2026), + "IntegerT": ("Integer", Integer, 2026), + "ScalarT": ("Scalar", Scalar, 2026), + "BoolT": ("Bool", Bool, 2026), + }) + + def not_none(x: T | None) -> T: assert x is not None return x diff --git a/pyproject.toml b/pyproject.toml index e4f2c6ea..f6e1ab96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ classifiers = [ ] dependencies = [ "immutabledict", - "pytools>=2022.1.14", + "pytools>=2024.1.16", # for dataclass_transform, TypeAlias, deprecated "typing-extensions>=4.5", "useful-types", diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 0d09fc9c..092c4718 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -3,7 +3,7 @@ from pymbolic.mapper.evaluator import evaluate_kw from pymbolic.mapper.flattener import FlattenMapper from pymbolic.mapper.stringifier import StringifyMapper -from pymbolic.typing import ExpressionT +from pymbolic.typing import Expression __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" @@ -567,7 +567,7 @@ def test_pickle(): assert expr == pickled -class OldTimeyExpression(prim.Expression): +class OldTimeyExpression(prim.ExpressionNode): init_arg_names = () def __getinitargs__(self): @@ -921,7 +921,7 @@ def __init__(self, cached_mapper, walk_call_functions=True): self.walk_call_functions = walk_call_functions def post_visit(self, expr): - if isinstance(expr, prim.Expression): + if isinstance(expr, prim.ExpressionNode): assert (self.cached_mapper.get_cache_key(expr) in self.cached_mapper._cache) @@ -1038,7 +1038,7 @@ def test_python_ast_interop_roundtrip(): @prim.expr_dataclass() class CustomOperator: - child: ExpressionT + child: Expression def make_stringifier(self, originating_stringifier=None): return OperatorStringifier() @@ -1058,7 +1058,7 @@ def test_derived_stringifier() -> None: # {{{ test_flatten class IntegerFlattenMapper(FlattenMapper): - def is_expr_integer_valued(self, expr: ExpressionT) -> bool: + def is_expr_integer_valued(self, expr: Expression) -> bool: return True diff --git a/test/testlib.py b/test/testlib.py index 7e697374..efcfb898 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -47,7 +47,7 @@ def with_increased_depth(self): def _generate_random_expr_inner( - context: RandomExpressionGeneratorContext) -> prim.Expression: + context: RandomExpressionGeneratorContext) -> prim.ExpressionNode: if context.current_depth >= context.max_depth: # force expression to be a leaf type @@ -92,7 +92,7 @@ def _generate_random_expr_inner( raise NotImplementedError(expr_type) -def generate_random_expression(seed: int, max_depth: int = 8) -> prim.Expression: +def generate_random_expression(seed: int, max_depth: int = 8) -> prim.ExpressionNode: from numpy.random import default_rng rng = default_rng(seed) vng = UniqueNameGenerator()