From c8e444fd09d06504a5c9b38c48b8a04a07ba741d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 6 Oct 2024 12:51:21 -0500 Subject: [PATCH] Do not simplify in overloaded operators --- pymbolic/primitives.py | 114 +++++++++++------------------------------ 1 file changed, 29 insertions(+), 85 deletions(-) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 9ebe5575..07b5c4d2 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -376,140 +376,84 @@ def init_arg_names(self) -> tuple[str, ...]: # {{{ arithmetic - def __add__(self, other: object) -> ArithmeticExpressionT: + def __add__(self, other: object) -> Sum: if not is_arithmetic_expression(other): return NotImplemented - if is_nonzero(other): - if self: - if isinstance(other, Sum): - return Sum((self, *other.children)) - else: - return Sum((self, other)) - else: - return other - else: - return self - - def __radd__(self, other: object) -> ArithmeticExpressionT: - assert is_number(other) - if is_nonzero(other): - if self: - return Sum((other, self)) - else: - return other - else: - return self + return Sum((self, other)) - def __sub__(self, other: object) -> ArithmeticExpressionT: - if not is_valid_operand(other): + def __radd__(self, other: object) -> Sum: + if not is_arithmetic_expression(other): return NotImplemented + return Sum((other, self)) - if is_nonzero(other): - return self.__add__(-cast(NumberT, other)) - else: - return self - - def __rsub__(self, other: object) -> ArithmeticExpressionT: - if not is_constant(other): + def __sub__(self, other: object) -> Sum: + if not is_arithmetic_expression(other): return NotImplemented + return Sum((self, -other)) - if is_nonzero(other): - return Sum((other, -self)) - else: - return -self + def __rsub__(self, other: object) -> Sum: + if not is_arithmetic_expression(other): + return NotImplemented + return Sum((other, -self)) - def __mul__(self, other: object) -> ArithmeticExpressionT: + def __mul__(self, other: object) -> Product: if not is_valid_operand(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other - 1): - return self - elif is_zero(other): - return 0 - else: - return Product((self, other)) + return Product((self, other)) - def __rmul__(self, other: object) -> ArithmeticExpressionT: - if not is_constant(other): + def __rmul__(self, other: object) -> Product: + if not is_valid_operand(other): return NotImplemented - if is_zero(other-1): - return self - elif is_zero(other): - return 0 - else: - return Product((other, self)) + return Product((other, self)) - def __div__(self, other: object) -> ArithmeticExpressionT: + def __truediv__(self, other: object) -> Quotient: if not is_valid_operand(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other-1): - return self - return quotient(self, other) - __truediv__ = __div__ + return Quotient(self, other) - def __rdiv__(self, other: object) -> ArithmeticExpressionT: + def __rtruediv__(self, other: object) -> Quotient: if not is_valid_operand(other): return NotImplemented - if is_zero(other): - return 0 - return quotient(other, self) - __rtruediv__ = __rdiv__ + return Quotient(other, self) - def __floordiv__(self, other: object) -> ArithmeticExpressionT: + def __floordiv__(self, other: object) -> FloorDiv: if not is_valid_operand(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other-1): - return self return FloorDiv(self, other) - def __rfloordiv__(self, other: object) -> ArithmeticExpressionT: + def __rfloordiv__(self, other: object) -> FloorDiv: if not is_arithmetic_expression(other): return NotImplemented - if is_zero(self-1): - return other return FloorDiv(other, self) - def __mod__(self, other: object) -> ArithmeticExpressionT: + def __mod__(self, other: object) -> Remainder: if not is_valid_operand(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other-1): - return 0 return Remainder(self, other) - def __rmod__(self, other: object) -> ArithmeticExpressionT: + def __rmod__(self, other: object) -> Remainder: if not is_valid_operand(other): return NotImplemented return Remainder(other, self) - def __pow__(self, other: object) -> ArithmeticExpressionT: + def __pow__(self, other: object) -> Power: if not is_valid_operand(other): return NotImplemented - other = cast(NumberT, other) - if is_zero(other): # exponent zero - return 1 - elif is_zero(other-1): # exponent one - return self return Power(self, other) - def __rpow__(self, other: object) -> ArithmeticExpressionT: - assert is_constant(other) + def __rpow__(self, other: object) -> Power: + if not is_valid_operand(other): + return NotImplemented - if is_zero(other): # base zero - return 0 - elif is_zero(other-1): # base one - return 1 return Power(other, self) # }}}