From 76431928893d93ec3b651bc3bbe0006089c9a920 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 20 Dec 2024 21:07:33 -0600 Subject: [PATCH] Fix FormSum weights --- ufl/action.py | 21 +++++++++++++++------ ufl/adjoint.py | 2 +- ufl/algorithms/map_integrands.py | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/ufl/action.py b/ufl/action.py index 14267d722..f744a8be6 100644 --- a/ufl/action.py +++ b/ufl/action.py @@ -69,12 +69,21 @@ def __new__(cls, *args, **kw): if isinstance(right, (Coargument, Argument)): return left - if isinstance(left, (FormSum, Sum)): - # Action distributes over sums on the LHS - return FormSum(*[(Action(component, right), 1) for component in left.ufl_operands]) - if isinstance(right, (FormSum, Sum)): - # Action also distributes over sums on the RHS - return FormSum(*[(Action(left, component), 1) for component in right.ufl_operands]) + # Action distributes over sums on the LHS + if isinstance(left, Sum): + return FormSum(*((Action(component, right), 1) for component in left.ufl_operands)) + elif isinstance(left, FormSum): + return FormSum( + *((Action(c, right), w) for c, w in zip(left.components(), left.weights())) + ) + + # Action also distributes over sums on the RHS + if isinstance(right, Sum): + return FormSum(*((Action(left, component), 1) for component in right.ufl_operands)) + elif isinstance(right, FormSum): + return FormSum( + *((Action(left, c), w) for c, w in zip(right.components(), right.weights())) + ) return super(Action, cls).__new__(cls) diff --git a/ufl/adjoint.py b/ufl/adjoint.py index 71227f578..7f1b1a556 100644 --- a/ufl/adjoint.py +++ b/ufl/adjoint.py @@ -49,7 +49,7 @@ def __new__(cls, *args, **kw): return form._form elif isinstance(form, FormSum): # Adjoint distributes over sums - return FormSum(*[(Adjoint(component), 1) for component in form.components()]) + return FormSum(*((Adjoint(c), w) for c, w in zip(form.components(), form.weights()))) elif isinstance(form, Coargument): # The adjoint of a coargument `c: V* -> V*` is the identity # matrix mapping from V to V (i.e. V x V* -> R). diff --git a/ufl/algorithms/map_integrands.py b/ufl/algorithms/map_integrands.py index 0a6da1817..71f7d8ff2 100644 --- a/ufl/algorithms/map_integrands.py +++ b/ufl/algorithms/map_integrands.py @@ -59,7 +59,7 @@ def map_integrands(function, form, only_integral_type=None): # Simplification of `BaseForm` objects may turn `FormSum` into a sum of `Expr` objects # that are not `BaseForm`, i.e. into a `Sum` object. # Example: `Action(Adjoint(c*), u)` with `c*` a `Coargument` and u a `Coefficient`. - return sum([component for component, _ in nonzero_components]) + return sum(component * w for component, w in nonzero_components) return FormSum(*nonzero_components) elif isinstance(form, Adjoint): # Zeros are caught inside `Adjoint.__new__`