diff --git a/galgebra/_full_set.py b/galgebra/_full_set.py new file mode 100644 index 00000000..3489724d --- /dev/null +++ b/galgebra/_full_set.py @@ -0,0 +1,50 @@ +class FullSet(set): + """ A set that contains everything. This is used to trick sympy. """ + def __contains__(self, x): + return True + + def __iter__(self): + raise RuntimeError("Set is infinite") + + def __len__(self): + raise RuntimeError("Set is infinite") + + def __and__(self, other): + if isinstance(other, set): + return other + return NotImplemented + __rand__ = __and__ + + def __or__(self, other): + if isinstance(other, set): + return self + return NotImplemented + __ror__ = __or__ + + def __gt__(self, other): + if isinstance(other, set): + return True + elif isinstance(other, FullSet): + return False + return NotImplemented + + def __lt__(self, other): + if isinstance(other, FullSet): + return False + return NotImplemented + + def __ge__(self, other): + return not (self < other) + + def __le__(self, other): + return not (self > other) + + def __eq__(self, other): + if isinstance(other, set): + return False + elif isinstance(other, FullSet): + return True + return NotImplemented + + def __bool__(self): + return True diff --git a/galgebra/dop.py b/galgebra/dop.py index 0fd07c60..d7eefdc3 100644 --- a/galgebra/dop.py +++ b/galgebra/dop.py @@ -3,16 +3,19 @@ For multivector-customized differential operators, see :class:`galgebra.mv.Dop`. """ +import abc import copy import numbers import warnings from typing import List, Tuple, Any, Iterable +import functools +import operator -from sympy import Symbol, S, Add, simplify, diff, Expr, Dummy +from sympy import Symbol, S, diff, Expr +import sympy from . import printer -from . import metric -from .printer import ZERO_STR +from . import _full_set def _consolidate_terms(terms): @@ -74,205 +77,94 @@ class _BaseDop(object): pass -class Sdop(_BaseDop): - """ - Scalar differential operator is of the form (Einstein summation) +class DiffOpExpr(Expr, _BaseDop): + free_symbols = _full_set.FullSet() + is_commutative = False + _op_priority = 50.0 - .. math:: D = c_{i}*D_{i} + def _diff_op_apply(self, x): + raise NotImplementedError - where the :math:`c_{i}`'s are scalar coefficient (they could be functions) - and the :math:`D_{i}`'s are partial differential operators (:class:`Pdop`). + def __add__(self, x): + return DiffOpAdd(self, x) - Attributes - ---------- - terms : tuple of tuple - the structure :math:`((c_{1},D_{1}),(c_{2},D_{2}), ...)` - """ + def __radd__(self, x): + return DiffOpAdd(x, self) - str_mode = False + def __mul__(self, x): + return DiffOpMul(self, x) - def TSimplify(self): - return Sdop([ - (metric.Simp.apply(coef), pdiff) for (coef, pdiff) in self.terms - ]) + def __rmul__(self, x): + return DiffOpMul(x, self) - @staticmethod - def consolidate_coefs(sdop): - """ - Remove zero coefs and consolidate coefs with repeated pdiffs. - """ - if isinstance(sdop, Sdop): - return Sdop(_consolidate_terms(sdop.terms)) - else: - return _consolidate_terms(sdop) - - def simplify(self, modes=simplify): - return Sdop([ - (metric.apply_function_list(modes, coef), pdiff) - for coef, pdiff in self.terms - ]) - - def _with_sorted_terms(self): - new_terms = sorted(self.terms, key=lambda term: Pdop.sort_key(term[1])) - return Sdop(new_terms) - - def Sdop_str(self): - if len(self.terms) == 0: - return ZERO_STR - - self = self._with_sorted_terms() - s = '' - for (coef, pdop) in self.terms: - coef_str = printer.latex(coef) - pd_str = printer.latex(pdop) - - if coef == S(1): - s += pd_str - elif coef == S(-1): - s += '-' + pd_str - else: - if isinstance(coef, Add): - s += '(' + coef_str + ')*' + pd_str - else: - s += coef_str + '*' + pd_str - s += ' + ' - - s = s.replace('+ -','- ') - s = s[:-3] - if Sdop.str_mode: - if len(self.terms) > 1 or isinstance(self.terms[0][0], Add): - s = '(' + s + ')' - return s + def __neg__(self): + return DiffOpMul(S.NegativeOne, self) - def Sdop_latex_str(self): - if len(self.terms) == 0: - return ZERO_STR - - self = self._with_sorted_terms() - - s = '' - for (coef, pdop) in self.terms: - coef_str = printer.latex(coef) - pd_str = printer.latex(pdop) - if coef == S(1): - if pd_str == '': - s += '1' - else: - s += pd_str - elif coef == S(-1): - if pd_str == '': - s += '-1' - else: - s += '-' + pd_str - else: - if isinstance(coef, Add): - s += r'\left ( ' + coef_str + r'\right ) ' + pd_str - else: - s += coef_str + ' ' + pd_str - s += ' + ' + def __sub__(self, x): + return DiffOpAdd(self, -x) - s = s.replace('+ -','- ') - return s[:-3] + def __rsub__(self, x): + return DiffOpAdd(x, -self) - def _repr_latex_(self): - latex_str = printer.GaLatexPrinter.latex(self) - return ' ' + latex_str + ' ' + def __call__(self, x): + return self._diff_op_apply(x) - def __str__(self): - if printer.GaLatexPrinter.latex_flg: - Printer = printer.GaLatexPrinter - else: - Printer = printer.GaPrinter - return Printer().doprint(self) +def _diff_op_ify(x): + if isinstance(x, DiffOpExpr): + return x + elif isinstance(x, sympy.Add): + return DiffOpAdd(*(a for a in x.args)) + elif isinstance(x, sympy.Mul): + return DiffOpMul(*(a for a in x.args)) + else: + return x * DiffOpPartial({}) + + +def _diff_op_apply(d, x): + if not isinstance(d, DiffOpExpr): + d = d * DiffOpPartial({}) + return _diff_op_ify(d)._diff_op_apply(x) - def __repr__(self): - return str(self) - def __init_from_symbol(self, symbol: Symbol) -> None: - self.terms = ((S(1), Pdop(symbol)),) +class Sdop(abc.ABC): + @classmethod + def __subclasshook__(cls, c): + return issubclass(c, DiffOpExpr) - def __init_from_coef_and_pdiffs(self, coefs: List[Any], pdiffs: List['Pdop']) -> None: + @classmethod + def _from_symbol(cls, symbol: Symbol) -> 'DiffOpExpr': + return Pdop(symbol) + + @classmethod + def _from_coef_and_pdiffs(cls, coefs: List[Any], pdiffs: List['Pdop']) -> None: if not isinstance(coefs, list) or not isinstance(pdiffs, list): raise TypeError("coefs and pdiffs must be lists") if len(coefs) != len(pdiffs): raise ValueError('In Sdop.__init__ coefficent list and Pdop list must be same length.') - self.terms = tuple(zip(coefs, pdiffs)) + return cls._from_terms(tuple(zip(coefs, pdiffs))) - def __init_from_terms(self, terms: Iterable[Tuple[Any, 'Pdop']]) -> None: - self.terms = tuple(terms) + @classmethod + def _from_terms(cls, terms: Iterable[Tuple[Any, 'Pdop']]) -> None: + return sum((a * b for a, b in terms), DiffOpAdd.identity) - def __init__(self, *args): + def __new__(cls, *args): if len(args) == 1: if isinstance(args[0], Symbol): - self.__init_from_symbol(*args) + return cls._from_symbol(*args) elif isinstance(args[0], (list, tuple)): - self.__init_from_terms(*args) + return cls._from_terms(*args) else: raise TypeError( "A symbol or sequence is required (got type {})" .format(type(args[0]).__name__)) elif len(args) == 2: - self.__init_from_coef_and_pdiffs(*args) + return cls._from_coef_and_pdiffs(*args) else: raise TypeError( "Sdop() takes from 1 to 2 positional arguments but {} were " "given".format(len(args))) - def __call__(self, arg): - # Ensure that we return the right type even when there are no terms - we - # do this by adding `0 * d(arg)/d(nonexistant)`, which must be zero, but - # will be a zero of the right type. - dummy_var = Dummy('nonexistant') - terms = self.terms or ((S(0), Pdop(dummy_var)),) - return sum([coef * pdiff(arg) for coef, pdiff in terms]) - - def __neg__(self): - return Sdop([(-coef, pdiff) for coef, pdiff in self.terms]) - - @staticmethod - def Add(sdop1, sdop2): - if isinstance(sdop1, Sdop) and isinstance(sdop2, Sdop): - return Sdop(_merge_terms(sdop1.terms, sdop2.terms)) - else: - # convert values to multiplicative operators - if not isinstance(sdop2, _BaseDop): - sdop2 = Sdop([(sdop2, Pdop({}))]) - elif not isinstance(sdop1, _BaseDop): - sdop1 = Sdop([(sdop1, Pdop({}))]) - else: - return NotImplemented - return Sdop.Add(sdop1, sdop2) - - def __eq__(self, other): - if isinstance(other, Sdop): - diff = self - other - return len(diff.terms) == 0 - else: - return NotImplemented - - def __add__(self, sdop): - return Sdop.Add(self, sdop) - - def __radd__(self, sdop): - return Sdop.Add(sdop, self) - - def __sub__(self, sdop): - return Sdop.Add(self, -sdop) - - def __rsub__(self, sdop): - return Sdop.Add(-self, sdop) - - def __mul__(self, sdopr): - # alias for applying the operator - return self.__call__(sdopr) - - def __rmul__(self, sdop): - return Sdop([(sdop * coef, pdiff) for coef, pdiff in self.terms]) - - def _eval_derivative_n_times(self, x, n): - return Sdop(_eval_derivative_n_times_terms(self.terms, x, n)) - #################### Partial Derivative Operator Class ################# @@ -288,7 +180,7 @@ def _basic_diff(f, x, n=1): raise ValueError('In_basic_diff type(arg) = ' + str(type(f)) + ' not allowed.') -class Pdop(_BaseDop): +class DiffOpPartial(DiffOpExpr, sympy.AtomicExpr): r""" Partial derivative operatorp. @@ -317,22 +209,14 @@ def sort_key(self, order=None): # lower order derivatives first self.order, # sorted by symbol after that, after expansion - sorted([ + tuple(sorted([ x.sort_key(order) for x, k in self.pdiffs.items() for i in range(k) - ]) + ])) ) - def __eq__(self,A): - if isinstance(A, Pdop) and self.pdiffs == A.pdiffs: - return True - else: - if len(self.pdiffs) == 0 and A == S(1): - return True - return False - - def __init__(self, __arg): + def __new__(cls, __arg): """ The partial differential operator is a partial derivative with respect to a set of real symbols (variables). @@ -345,13 +229,19 @@ def __init__(self, __arg): __arg = {} if isinstance(__arg, dict): # Pdop defined by dictionary - self.pdiffs = __arg + pdiffs = __arg elif isinstance(__arg, Symbol): # First order derivative with respect to symbol - self.pdiffs = {__arg: 1} + pdiffs = {__arg: 1} else: raise TypeError('A dictionary or symbol is required, got {!r}'.format(__arg)) + self = super().__new__(cls) + self.pdiffs = pdiffs self.order = sum(self.pdiffs.values()) + return self + + def _eval_derivative(self, x): + return self._eval_derivative_n_times(x, 1) def _eval_derivative_n_times(self, x, n) -> 'Pdop': # pdiff(self) # d is partial derivative @@ -362,7 +252,7 @@ def _eval_derivative_n_times(self, x, n) -> 'Pdop': # pdiff(self) pdiffs[x] = n return Pdop(pdiffs) - def __call__(self, arg): + def _diff_op_apply(self, arg): """ Calculate nth order partial derivative (order defined by self) of expression @@ -371,13 +261,6 @@ def __call__(self, arg): arg = _basic_diff(arg, x, n) return arg - def __mul__(self, other): # functional product of self and arg (self*arg) - return self(other) - - def __rmul__(self, other): # functional product of arg and self (arg*self) - assert not isinstance(other, Pdop) - return Sdop([(other, self)]) - def Pdop_str(self): if self.order == 0: return 'D{}' @@ -419,3 +302,135 @@ def __str__(self): def __repr__(self): return str(self) + + def __srepr__(self): + return '{}({})'.format(type(self).__name__, self.pdiffs) + + def _hashable_content(self): + from sympy.utilities import default_sort_key + sorted_items = sorted( + self.pdiffs.items(), + key=lambda t: (default_sort_key(t[0]), t[1]) + ) + return tuple(sympy.Basic(coeff, sympy.S(n)) for coeff, n in sorted_items) + + +Pdop = DiffOpPartial + + +class DiffOpZero(DiffOpExpr): + def _diff_op_apply(self, x): + return 0 * x + + +class DiffOpMul(DiffOpExpr, sympy.Mul): + identity = DiffOpPartial({}) + + def __new__(cls, *args, **kwargs): + if not args: + return cls.identity + pre_coeffs = [] + it_args = iter(args) + + pre_coeffs = [] + diff_ops = [] + diff_operands = [] + + # extra pre-multiplied coeffs + for a in it_args: + if isinstance(a, DiffOpExpr): + diff_ops.append(a) + break + pre_coeffs.append(a) + + # extract differential terms + for a in it_args: + if not isinstance(a, DiffOpExpr): + diff_operands.append(a) + break + diff_ops.append(a) + + # must be only one operand + for a in it_args: + raise TypeError( + "Must pass at most one operand after the differential operators" + ) + + # avoid `sympy.Mul` so that this works on multivectors + if pre_coeffs: + coeff = functools.reduce(operator.mul, pre_coeffs) + else: + coeff = S(1) + + d = cls.identity + for di in diff_ops: + d = d._diff_op_apply(di) + if coeff == S(1): + self = d + elif coeff == S(0): + self = DiffOpZero() + else: + self = sympy.Basic.__new__(cls, coeff, d, **kwargs) + + if diff_operands: + return self._diff_op_apply(diff_operands[0]) + + return self + + def _diff_op_apply(self, x): + coeff, *rest = self.args + for r in rest[::-1]: + x = r._diff_op_apply(x) + return coeff * x + + def diff(self, *args, **kwargs): + return super().diff(*args, simplify=False, **kwargs) + + def _eval_derivative(self, x): + coeff, diff = self.args + return sympy.diff(coeff, x) * diff + coeff * sympy.diff(diff, x) + + +class DiffOpAdd(DiffOpExpr, sympy.Add): + identity = DiffOpZero() + + @classmethod + def _from_args(self, args, is_commutative): + args = [_diff_op_ify(arg) for arg in args] + return super()._from_args(args, is_commutative) + + def _diff_op_apply(self, x): + args = self.args + assert args + # avoid `sympy.Add` so that this works on multivectors + return functools.reduce(operator.add, (_diff_op_apply(a, x) for a in args)) + + def _eval_derivative_n_times(self, x, n): + return DiffOpAdd(*(a.diff(x, n) for a in self.args)) + + def _eval_derivative(self, x): + return DiffOpAdd(*(a.diff(x) for a in self.args)) + + def _eval_simplify(self, *args, **kwargs): + return self + + def diff(self, *args, **kwargs): + return super().diff(*args, simplify=False, **kwargs) + + +def _as_terms(d): + d = d.expand() + if isinstance(d, DiffOpAdd): + for a in d.args: + yield from _as_terms(a) + elif isinstance(d, DiffOpMul): + coeff, pdiff = d.args + yield (coeff, pdiff) + elif isinstance(d, DiffOpPartial): + yield (S(1), d) + elif isinstance(d, DiffOpZero): + pass + elif isinstance(d, sympy.Atom): + yield (d, DiffOpMul.identity) + else: + raise NotImplementedError("cannot convert {} to terms".format(sympy.srepr(d))) diff --git a/galgebra/mv.py b/galgebra/mv.py index 7ae9a77e..cc80f1a5 100644 --- a/galgebra/mv.py +++ b/galgebra/mv.py @@ -60,6 +60,14 @@ class Mv(object): ################### Multivector initialization ##################### + is_number = False + is_Number = False + is_Rational = False + is_commutative = False + + def sort_key(self, *args, **kwargs): + return self.obj.sort_key(*args, **kwargs) + fmt = 1 latex_flg = False restore = False @@ -1462,7 +1470,7 @@ def __init_from_terms(self, terms: Union[ self.terms = dop._consolidate_terms( (coef * mv, pdiff) for (sdop, mv) in terms - for (coef, pdiff) in sdop.terms + for (coef, pdiff) in dop._terms_of(sdop) ) else: raise TypeError( @@ -1723,7 +1731,7 @@ def Dop_mv_expand(self, modes=None): coefs.append(dop.Sdop([(mv_coef, pdiff)])) if modes is not None: for i in range(len(coefs)): - coefs[i] = coefs[i].simplify(modes) + coefs[i] = coefs[i].simplify(modes=modes) terms = list(zip(coefs, bases)) return sorted(terms, key=lambda x: self.Ga._all_blades_lst.index(x[1])) @@ -1740,13 +1748,14 @@ def Dop_str(self): if base == S(1): s += str_sdop else: - if len(sdop.terms) > 1: + terms = list(dop._as_terms(sdop)) + if len(terms) > 1: if self.cmpflg: s += '(' + str_sdop + ')*' + str_base else: s += str_base + '*(' + str_sdop + ')' else: - if str_sdop[0] == '-' and not isinstance(sdop.terms[0][0], Add): + if str_sdop[0] == '-' and not isinstance(terms[0], Add): if self.cmpflg: s += str_sdop + '*' + str_base else: @@ -1783,13 +1792,14 @@ def Dop_latex_str(self): if str_sdop[1:] != '1': s += ' ' + str_sdop[1:] else: - if len(sdop.terms) > 1: + terms = list(dop._as_terms(sdop)) + if len(terms) > 1: if self.cmpflg: s += r'\left ( ' + str_sdop + r'\right ) ' + str_base else: s += str_base + ' ' + r'\left ( ' + str_sdop + r'\right ) ' else: - if str_sdop[0] == '-' and not isinstance(sdop.terms[0][0], Add): + if str_sdop[0] == '-' and not isinstance(terms[0][0], Add): if self.cmpflg: s += str_sdop + str_base else: diff --git a/galgebra/printer.py b/galgebra/printer.py index 8c991989..821bd9f5 100644 --- a/galgebra/printer.py +++ b/galgebra/printer.py @@ -371,7 +371,7 @@ def _print_Mv(self, expr): else: return expr.Mv_str() - def _print_Pdop(self, expr): + def _print_DiffOpPartial(self, expr): return expr.Pdop_str() def _print_Dop(self, expr):