Skip to content

Commit

Permalink
Do not simplify in overloaded operators
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 6, 2024
1 parent 5e05471 commit c8e444f
Showing 1 changed file with 29 additions and 85 deletions.
114 changes: 29 additions & 85 deletions pymbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,140 +376,84 @@ def init_arg_names(self) -> tuple[str, ...]:

# {{{ arithmetic

def __add__(self, other: object) -> ArithmeticExpressionT:
def __add__(self, other: object) -> Sum:
if not is_arithmetic_expression(other):
return NotImplemented
if is_nonzero(other):
if self:
if isinstance(other, Sum):
return Sum((self, *other.children))
else:
return Sum((self, other))
else:
return other
else:
return self

def __radd__(self, other: object) -> ArithmeticExpressionT:
assert is_number(other)
if is_nonzero(other):
if self:
return Sum((other, self))
else:
return other
else:
return self
return Sum((self, other))

def __sub__(self, other: object) -> ArithmeticExpressionT:
if not is_valid_operand(other):
def __radd__(self, other: object) -> Sum:
if not is_arithmetic_expression(other):
return NotImplemented
return Sum((other, self))

if is_nonzero(other):
return self.__add__(-cast(NumberT, other))
else:
return self

def __rsub__(self, other: object) -> ArithmeticExpressionT:
if not is_constant(other):
def __sub__(self, other: object) -> Sum:
if not is_arithmetic_expression(other):
return NotImplemented
return Sum((self, -other))

if is_nonzero(other):
return Sum((other, -self))
else:
return -self
def __rsub__(self, other: object) -> Sum:
if not is_arithmetic_expression(other):
return NotImplemented
return Sum((other, -self))

def __mul__(self, other: object) -> ArithmeticExpressionT:
def __mul__(self, other: object) -> Product:
if not is_valid_operand(other):
return NotImplemented

other = cast(NumberT, other)
if is_zero(other - 1):
return self
elif is_zero(other):
return 0
else:
return Product((self, other))
return Product((self, other))

def __rmul__(self, other: object) -> ArithmeticExpressionT:
if not is_constant(other):
def __rmul__(self, other: object) -> Product:
if not is_valid_operand(other):
return NotImplemented

if is_zero(other-1):
return self
elif is_zero(other):
return 0
else:
return Product((other, self))
return Product((other, self))

def __div__(self, other: object) -> ArithmeticExpressionT:
def __truediv__(self, other: object) -> Quotient:
if not is_valid_operand(other):
return NotImplemented

other = cast(NumberT, other)
if is_zero(other-1):
return self
return quotient(self, other)
__truediv__ = __div__
return Quotient(self, other)

def __rdiv__(self, other: object) -> ArithmeticExpressionT:
def __rtruediv__(self, other: object) -> Quotient:
if not is_valid_operand(other):
return NotImplemented

if is_zero(other):
return 0
return quotient(other, self)
__rtruediv__ = __rdiv__
return Quotient(other, self)

def __floordiv__(self, other: object) -> ArithmeticExpressionT:
def __floordiv__(self, other: object) -> FloorDiv:
if not is_valid_operand(other):
return NotImplemented

other = cast(NumberT, other)
if is_zero(other-1):
return self
return FloorDiv(self, other)

def __rfloordiv__(self, other: object) -> ArithmeticExpressionT:
def __rfloordiv__(self, other: object) -> FloorDiv:
if not is_arithmetic_expression(other):
return NotImplemented

if is_zero(self-1):
return other
return FloorDiv(other, self)

def __mod__(self, other: object) -> ArithmeticExpressionT:
def __mod__(self, other: object) -> Remainder:
if not is_valid_operand(other):
return NotImplemented

other = cast(NumberT, other)
if is_zero(other-1):
return 0
return Remainder(self, other)

def __rmod__(self, other: object) -> ArithmeticExpressionT:
def __rmod__(self, other: object) -> Remainder:
if not is_valid_operand(other):
return NotImplemented

return Remainder(other, self)

def __pow__(self, other: object) -> ArithmeticExpressionT:
def __pow__(self, other: object) -> Power:
if not is_valid_operand(other):
return NotImplemented

other = cast(NumberT, other)
if is_zero(other): # exponent zero
return 1
elif is_zero(other-1): # exponent one
return self
return Power(self, other)

def __rpow__(self, other: object) -> ArithmeticExpressionT:
assert is_constant(other)
def __rpow__(self, other: object) -> Power:
if not is_valid_operand(other):
return NotImplemented

if is_zero(other): # base zero
return 0
elif is_zero(other-1): # base one
return 1
return Power(other, self)

# }}}
Expand Down

0 comments on commit c8e444f

Please sign in to comment.