diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 909db6990a..aaf4239556 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,6 +83,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch pyadjoint dham/abstract_reduced_functional || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index e4664665b0..4c4020a249 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -86,8 +86,15 @@ def _init_solver_parameters(self, args, kwargs): self.assemble_kwargs = {} def __str__(self): - return "solve({} = {})".format(ufl2unicode(self.lhs), - ufl2unicode(self.rhs)) + try: + lhs_string = ufl2unicode(self.lhs) + except AttributeError: + lhs_string = str(self.lhs) + try: + rhs_string = ufl2unicode(self.rhs) + except AttributeError: + rhs_string = str(self.rhs) + return "solve({} = {})".format(lhs_string, rhs_string) def _create_F_form(self): # Process the equation forms, replacing values with checkpoints, @@ -756,7 +763,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, c = block_variable.output c_rep = block_variable.saved_output - if isinstance(c, firedrake.Function): + if isinstance(c, (firedrake.Function, firedrake.Cofunction)): trial_function = firedrake.TrialFunction(c.function_space()) elif isinstance(c, firedrake.Constant): mesh = F_form.ufl_domain() @@ -793,7 +800,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, replace_map[self.func] = self.get_outputs()[0].saved_output dFdm = replace(dFdm, replace_map) - dFdm = dFdm * adj_sol + if isinstance(dFdm, firedrake.Argument): + # Corner case. Should be fixed more permanently upstream in UFL. + dFdm = ufl.Action(dFdm, adj_sol) + else: + dFdm = dFdm * adj_sol dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs) return dFdm diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 5e87751d36..f593bf605d 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -2,7 +2,7 @@ import ufl from ufl.domain import extract_unique_domain from pyadjoint.overloaded_type import create_overloaded_object, FloatingType -from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape, no_annotations +from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape from firedrake.adjoint_utils.blocks import FunctionAssignBlock, ProjectBlock, SubfunctionBlock, FunctionMergeBlock, SupermeshProjectBlock import firedrake from .checkpointing import disk_checkpointing, CheckpointFunction, \ @@ -220,72 +220,15 @@ def _ad_create_checkpoint(self): else: return self.copy(deepcopy=True) - def _ad_convert_riesz(self, value, options=None): - from firedrake import Function, Cofunction - - options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "L2") - solver_options = options.get("solver_options", {}) - V = options.get("function_space", self.function_space()) - if value == 0.: - # In adjoint-based differentiation, value == 0. arises only when - # the functional is independent on the control variable. - return Function(V) - - if not isinstance(value, (Cofunction, Function)): - raise TypeError("Expected a Cofunction or a Function") - - if riesz_representation == "l2": - return Function(V, val=value.dat) - - elif riesz_representation in ("L2", "H1"): - if not isinstance(value, Cofunction): - raise TypeError("Expected a Cofunction") - - ret = Function(V) - a = self._define_riesz_map_form(riesz_representation, V) - firedrake.solve(a == value, ret, **solver_options) - return ret - - elif callable(riesz_representation): - return riesz_representation(value) - - else: - raise ValueError( - "Unknown Riesz representation %s" % riesz_representation) + def _ad_convert_riesz(self, value, riesz_map=None): + return value.riesz_representation(riesz_map=riesz_map or "L2") - def _define_riesz_map_form(self, riesz_representation, V): - from firedrake import TrialFunction, TestFunction - - u = TrialFunction(V) - v = TestFunction(V) - if riesz_representation == "L2": - a = firedrake.inner(u, v)*firedrake.dx - - elif riesz_representation == "H1": - a = firedrake.inner(u, v)*firedrake.dx \ - + firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx - - else: - raise NotImplementedError( - "Unknown Riesz representation %s" % riesz_representation) - return a - - @no_annotations - def _ad_convert_type(self, value, options=None): - # `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz` - options = {} if options is None else options.copy() - options.setdefault("riesz_representation", "L2") - if options["riesz_representation"] is None: - if value == 0.: - # In adjoint-based differentiation, value == 0. arises only when - # the functional is independent on the control variable. - V = options.get("function_space", self.function_space()) - return firedrake.Cofunction(V.dual()) - else: - return value + def _ad_init_zero(self, dual=False): + from firedrake import Function, Cofunction + if dual: + return Cofunction(self.function_space().dual()) else: - return self._ad_convert_riesz(value, options=options) + return Function(self.function_space()) def _ad_restore_at_checkpoint(self, checkpoint): if isinstance(checkpoint, CheckpointBase): @@ -294,9 +237,7 @@ def _ad_restore_at_checkpoint(self, checkpoint): return checkpoint def _ad_will_add_as_dependency(self): - """Method called when the object is added as a Block dependency. - - """ + """Method called when the object is added as a Block dependency.""" with checkpoint_init_data(): super()._ad_will_add_as_dependency() @@ -304,7 +245,8 @@ def _ad_mul(self, other): from firedrake import Function r = Function(self.function_space()) - # `self` can be a Cofunction in which case only left multiplication with a scalar is allowed. + # `self` can be a Cofunction in which case only left multiplication + # with a scalar is allowed. r.assign(other * self) return r @@ -316,7 +258,10 @@ def _ad_add(self, other): return r def _ad_dot(self, other, options=None): - from firedrake import assemble + from firedrake import assemble, action, Cofunction + + if isinstance(other, Cofunction): + return assemble(action(other, self)) options = {} if options is None else options riesz_representation = options.get("riesz_representation", "L2") @@ -406,3 +351,9 @@ def _ad_to_petsc(self, vec=None): def __deepcopy__(self, memodict={}): return self.copy(deepcopy=True) + + +class CofunctionMixin(FunctionMixin): + + def _ad_dot(self, other): + return firedrake.assemble(firedrake.action(self, other)) diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index 4878e6da59..6e8bebe8bd 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -1,3 +1,4 @@ +from functools import cached_property import numpy as np import ufl @@ -8,16 +9,16 @@ import firedrake.functionspaceimpl as functionspaceimpl from firedrake import utils, vector, ufl_expr from firedrake.utils import ScalarType -from firedrake.adjoint_utils.function import FunctionMixin +from firedrake.adjoint_utils.function import CofunctionMixin from firedrake.adjoint_utils.checkpointing import DelegatedFunctionCheckpoint from firedrake.adjoint_utils.blocks.function import CofunctionAssignBlock from firedrake.petsc import PETSc -class Cofunction(ufl.Cofunction, FunctionMixin): +class Cofunction(ufl.Cofunction, CofunctionMixin): r"""A :class:`Cofunction` represents a function on a dual space. - Like Functions, cofunctions are - represented as sums of basis functions: + + Like Functions, cofunctions are represented as sums of basis functions: .. math:: @@ -33,7 +34,7 @@ class Cofunction(ufl.Cofunction, FunctionMixin): """ @PETSc.Log.EventDecorator() - @FunctionMixin._ad_annotate_init + @CofunctionMixin._ad_annotate_init def __init__(self, function_space, val=None, name=None, dtype=ScalarType, count=None): r""" @@ -105,13 +106,13 @@ def _analyze_form_arguments(self): self._coefficients = (self,) @utils.cached_property - @FunctionMixin._ad_annotate_subfunctions + @CofunctionMixin._ad_annotate_subfunctions def subfunctions(self): r"""Extract any sub :class:`Cofunction`\s defined on the component spaces of this this :class:`Cofunction`'s :class:`.FunctionSpace`.""" return tuple(type(self)(fs, dat) for fs, dat in zip(self.function_space(), self.dat)) - @FunctionMixin._ad_annotate_subfunctions + @CofunctionMixin._ad_annotate_subfunctions def split(self): import warnings warnings.warn("The .split() method is deprecated, please use the .subfunctions property instead", category=FutureWarning) @@ -228,39 +229,45 @@ def assign(self, expr, subset=None, expr_from_assemble=False): raise ValueError('Cannot assign %s' % expr) - def riesz_representation(self, riesz_map='L2', **solver_options): - """Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. + def riesz_representation(self, riesz_map='L2', *, bcs=None, + solver_options=None, + form_compiler_parameters=None): + """Return the Riesz representation of this :class:`Cofunction`. - Example: For a L2 Riesz map, the Riesz representation is obtained by solving - the linear system ``Mx = self``, where M is the L2 mass matrix, i.e. M = - with u and v trial and test functions, respectively. + Example: For a L2 Riesz map, the Riesz representation is obtained by + solving the linear system ``Mx = self``, where M is the L2 mass matrix, + i.e. M = with u and v trial and test functions, respectively. Parameters ---------- - riesz_map : str or collections.abc.Callable - The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. - solver_options : dict - Solver options to pass to the linear solver: - - solver_parameters: optional solver parameters. - - nullspace: an optional :class:`.VectorSpaceBasis` (or :class:`.MixedVectorSpaceBasis`) - spanning the null space of the operator. - - transpose_nullspace: as for the nullspace, but used to make the right hand side consistent. - - near_nullspace: as for the nullspace, but used to add the near nullspace. - - options_prefix: an optional prefix used to distinguish PETSc options. - If not provided a unique prefix will be created. - Use this option if you want to pass options to the solver from the command line - in addition to through the ``solver_parameters`` dict. + riesz_map : str or ufl.sobolevspace.SobolevSpace or + collections.abc.Callable + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a + callable. + bcs: DirichletBC or list of DirichletBC + Boundary conditions to apply to the Riesz map. + solver_options: dict + A dictionary of PETSc options to be passed to the solver. + form_compiler_parameters: dict + A dictionary of form compiler parameters to be passed to the + variational problem that solves for the Riesz map. Returns ------- firedrake.function.Function - Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. + Riesz representation of this :class:`Cofunction` with respect to + the given Riesz map. """ - return self._ad_convert_riesz(self, options={"function_space": self.function_space().dual(), - "riesz_representation": riesz_map, - "solver_options": solver_options}) + if not callable(riesz_map): + riesz_map = RieszMap( + self.function_space(), riesz_map, bcs=bcs, + solver_parameters=solver_options, + form_compiler_parameters=form_compiler_parameters + ) + + return riesz_map(self) - @FunctionMixin._ad_annotate_iadd + @CofunctionMixin._ad_annotate_iadd @utils.known_pyop2_safe def __iadd__(self, expr): @@ -276,7 +283,7 @@ def __iadd__(self, expr): # Let Python hit `BaseForm.__add__` which relies on ufl.FormSum. return NotImplemented - @FunctionMixin._ad_annotate_isub + @CofunctionMixin._ad_annotate_isub @utils.known_pyop2_safe def __isub__(self, expr): @@ -293,7 +300,7 @@ def __isub__(self, expr): # Let Python hit `BaseForm.__sub__` which relies on ufl.FormSum. return NotImplemented - @FunctionMixin._ad_annotate_imul + @CofunctionMixin._ad_annotate_imul def __imul__(self, expr): if np.isscalar(expr): @@ -360,3 +367,126 @@ def __str__(self): def cell_node_map(self): return self.function_space().cell_node_map() + + +class RieszMap: + """Return a map between dual and primal function spaces. + + A `RieszMap` can be called on a `Cofunction` in the appropriate space to + yield the `Function` which is the Riesz representer under the given inner + product. Conversely, it can be called on a `Function` to apply the given + inner product and return a `Cofunction`. + + Parameters + ---------- + function_space_or_inner_product: FunctionSpace or ufl.Form + The space from which to map, or a bilinear form defining an inner + product. + sobolev_space: str or ufl.sobolevspace.SobolevSpace. + Used to determine the inner product. + bcs: DirichletBC or list of DirichletBC + Boundary conditions to apply to the Riesz map. + solver_parameters: dict + A dictionary of PETSc options to be passed to the solver. + form_compiler_parameters: dict + A dictionary of form compiler parameters to be passed to the + variational problem that solves for the Riesz map. + """ + + def __init__(self, function_space_or_inner_product=None, + sobolev_space=ufl.L2, *, bcs=None, solver_parameters=None, + form_compiler_parameters=None): + if isinstance(function_space_or_inner_product, ufl.Form): + args = ufl.algorithms.extract_arguments( + function_space_or_inner_product + ) + if len(args) != 2: + raise ValueError(f"inner_product has arity {len(args)}, " + "should be 2.") + function_space = args[0].function_space() + inner_product = function_space_or_inner_product + else: + function_space = function_space_or_inner_product + if hasattr(function_space, "function_space"): + function_space = function_space.function_space() + if ufl.duals.is_dual(function_space): + function_space = function_space.dual() + + if str(sobolev_space) == "l2": + inner_product = "l2" + else: + from firedrake import TrialFunction, TestFunction + u = TrialFunction(function_space) + v = TestFunction(function_space) + inner_product = RieszMap._inner_product_form( + sobolev_space, u, v + ) + + self._function_space = function_space + self._inner_product = inner_product + self._bcs = bcs + self._solver_parameters = solver_parameters or {} + self._form_compiler_parameters = form_compiler_parameters or {} + + @staticmethod + def _inner_product_form(sobolev_space, u, v): + from firedrake import inner, dx, grad + inner_products = { + "L2": lambda u, v: inner(u, v)*dx, + "H1": lambda u, v: inner(u, v)*dx + inner(grad(u), grad(v))*dx + } + try: + return inner_products[str(sobolev_space)](u, v) + except KeyError: + raise ValueError("No inner product defined for Sobolev space " + f"{sobolev_space}.") + + @cached_property + def _solver(self): + from firedrake import (LinearVariationalSolver, + LinearVariationalProblem, Function, Cofunction) + rhs = Cofunction(self._function_space.dual()) + soln = Function(self._function_space) + lvp = LinearVariationalProblem( + self._inner_product, rhs, soln, bcs=self._bcs, restrict=True, + form_compiler_parameters=self._form_compiler_parameters) + solver = LinearVariationalSolver( + lvp, solver_parameters=self._solver_parameters + ) + return solver.solve, rhs, soln + + def __call__(self, value): + """Return the Riesz representer of a Function or Cofunction.""" + from firedrake import Function, Cofunction + + if ufl.duals.is_dual(value): + if value.function_space().dual() != self._function_space: + raise ValueError("Function space mismatch in RieszMap.") + output = Function(self._function_space) + + if self._inner_product == "l2": + for o, c in zip(output.subfunctions, value.subfunctions): + o.dat.data[:] = c.dat.data[:] + else: + solve, rhs, soln = self._solver + rhs.assign(value) + solve() + output = Function(self._function_space) + output.assign(soln) + elif ufl.duals.is_primal(value): + if value.function_space().dual() != self._function_space: + raise ValueError("Function space mismatch in RieszMap.") + + if self._inner_product == "l2": + output = Cofunction(self._function_space.dual()) + for o, c in zip(output.subfunctions, value.subfunctions): + o.dat.data[:] = c.dat.data[:] + else: + output = firedrake.assemble( + firedrake.action(self._inner_product, value) + ) + else: + raise ValueError( + f"Unable to ascertain if {value} is primal or dual." + ) + return output diff --git a/firedrake/function.py b/firedrake/function.py index da4d264971..5aae0fe561 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -18,7 +18,7 @@ from firedrake.utils import ScalarType, IntType, as_ctypes from firedrake import functionspaceimpl -from firedrake.cofunction import Cofunction +from firedrake.cofunction import Cofunction, RieszMap from firedrake import utils from firedrake import vector from firedrake.adjoint_utils import FunctionMixin @@ -479,39 +479,29 @@ def assign(self, expr, subset=None): return self def riesz_representation(self, riesz_map='L2'): - """Return the Riesz representation of this :class:`Function` with respect to the given Riesz map. + """Return the Riesz representation of this :class:`Function`. - Example: For a L2 Riesz map, the Riesz representation is obtained by taking the action - of ``M`` on ``self``, where M is the L2 mass matrix, i.e. M = - with u and v trial and test functions, respectively. + Example: For a L2 Riesz map, the Riesz representation is obtained by + taking the action of ``M`` on ``self``, where M is the L2 mass matrix, + i.e. M = with u and v trial and test functions, respectively. Parameters ---------- - riesz_map : str or collections.abc.Callable - The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. + riesz_map : str or ufl.sobolevspace.SobolevSpace or + collections.abc.Callable + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a + callable which applies the Riesz map. Returns ------- firedrake.cofunction.Cofunction - Riesz representation of this :class:`Function` with respect to the given Riesz map. + Riesz representation of this :class:`Function` with respect to the + given Riesz map. """ - from firedrake.ufl_expr import action - from firedrake.assemble import assemble + if not callable(riesz_map): + riesz_map = RieszMap(self.function_space(), riesz_map) - V = self.function_space() - if riesz_map == "l2": - return Cofunction(V.dual(), val=self.dat) - - elif riesz_map in ("L2", "H1"): - a = self._define_riesz_map_form(riesz_map, V) - return assemble(action(a, self)) - - elif callable(riesz_map): - return riesz_map(self) - - else: - raise NotImplementedError( - "Unknown Riesz representation %s" % riesz_map) + return riesz_map(self) @FunctionMixin._ad_annotate_iadd def __iadd__(self, expr): diff --git a/firedrake/ufl_expr.py b/firedrake/ufl_expr.py index 1f9a66df75..f7108a9cc7 100644 --- a/firedrake/ufl_expr.py +++ b/firedrake/ufl_expr.py @@ -45,6 +45,12 @@ def __init__(self, function_space, number, part=None): number, part=part) self._function_space = function_space + def arguments(self): + return (self,) + + def coefficients(self): + return () + @utils.cached_property def cell_node_map(self): return self.function_space().cell_node_map diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 087d9b091d..4a26cfc6d1 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -211,6 +211,7 @@ def __init__(self, global_): self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) self._initialized = True + self._apply_local_global_filter = False @classmethod def _cache_key(cls, *args): diff --git a/tests/firedrake/regression/test_adjoint_operators.py b/tests/firedrake/regression/test_adjoint_operators.py index 8d4c0ab796..540932913b 100644 --- a/tests/firedrake/regression/test_adjoint_operators.py +++ b/tests/firedrake/regression/test_adjoint_operators.py @@ -864,10 +864,10 @@ def test_assign_zero_cofunction(): J = assemble(((sol + Constant(1.0)) ** 2) * dx) # The zero assignment should break the tape and hence cause a zero # gradient. - grad_l2 = compute_gradient(J, Control(k), options={"riesz_representation": "l2"}) - grad_none = compute_gradient(J, Control(k), options={"riesz_representation": None}) - grad_h1 = compute_gradient(J, Control(k), options={"riesz_representation": "H1"}) - grad_L2 = compute_gradient(J, Control(k), options={"riesz_representation": "L2"}) + grad_l2 = compute_gradient(J, Control(k, riesz_map="l2"), apply_riesz=True) + grad_none = compute_gradient(J, Control(k)) + grad_h1 = compute_gradient(J, Control(k, riesz_map="H1"), apply_riesz=True) + grad_L2 = compute_gradient(J, Control(k, riesz_map="L2"), apply_riesz=True) assert isinstance(grad_l2, Function) and isinstance(grad_L2, Function) \ and isinstance(grad_h1, Function) assert isinstance(grad_none, Cofunction) @@ -913,7 +913,6 @@ def test_riesz_representation_for_adjoints(): space = FunctionSpace(mesh, "Lagrange", 1) f = Function(space).interpolate(SpatialCoordinate(mesh)[0]) J = assemble((f ** 2) * dx) - rf = ReducedFunctional(J, Control(f)) with stop_annotating(): v = TestFunction(space) u = TrialFunction(space) @@ -932,21 +931,27 @@ def test_riesz_representation_for_adjoints(): dJdu_function_L2 = Function(space) solve(a == dJdu_cofunction, dJdu_function_L2) - dJdu_none = rf.derivative(options={"riesz_representation": None}) - dJdu_l2 = rf.derivative(options={"riesz_representation": "l2"}) - dJdu_H1 = rf.derivative(options={"riesz_representation": "H1"}) - dJdu_L2 = rf.derivative(options={"riesz_representation": "L2"}) - dJdu_default_L2 = rf.derivative() - assert ( - isinstance(dJdu_none, Cofunction) and isinstance(dJdu_function_l2, Function) - and isinstance(dJdu_H1, Function) and isinstance(dJdu_default_L2, Function) - and isinstance(dJdu_L2, Function) - and np.allclose(dJdu_none.dat.data, dJdu_cofunction.dat.data) - and np.allclose(dJdu_l2.dat.data, dJdu_function_l2.dat.data) - and np.allclose(dJdu_H1.dat.data, dJdu_function_H1.dat.data) - and np.allclose(dJdu_default_L2.dat.data, dJdu_function_L2.dat.data) - and np.allclose(dJdu_L2.dat.data, dJdu_function_L2.dat.data) + dJdu_none = ReducedFunctional(J, Control(f)).derivative() + dJdu_l2 = ReducedFunctional(J, Control(f, riesz_map="l2")).derivative( + apply_riesz=True + ) + dJdu_H1 = ReducedFunctional(J, Control(f, riesz_map="H1")).derivative( + apply_riesz=True + ) + dJdu_L2 = ReducedFunctional(J, Control(f, riesz_map="L2")).derivative( + apply_riesz=True + ) + dJdu_default_L2 = ReducedFunctional(J, Control(f)).derivative( + apply_riesz=True ) + assert isinstance(dJdu_none, Cofunction) and isinstance(dJdu_function_l2, Function) + assert isinstance(dJdu_H1, Function) and isinstance(dJdu_default_L2, Function) + assert isinstance(dJdu_L2, Function) + assert np.allclose(dJdu_none.dat.data, dJdu_cofunction.dat.data) + assert np.allclose(dJdu_l2.dat.data, dJdu_function_l2.dat.data) + assert np.allclose(dJdu_H1.dat.data, dJdu_function_H1.dat.data) + assert np.allclose(dJdu_default_L2.dat.data, dJdu_function_L2.dat.data) + assert np.allclose(dJdu_L2.dat.data, dJdu_function_L2.dat.data) @pytest.mark.skipcomplex @@ -998,7 +1003,7 @@ def test_cofunction_assign_functional(): cof2.assign(cof) # Test is checking that this is taped. J = assemble(action(cof2, f2)) Jhat = ReducedFunctional(J, Control(f)) - assert np.allclose(float(Jhat.derivative()), 1.0) + assert np.allclose(float(Jhat.derivative(apply_riesz=True)), 1.0) f.assign(2.0) assert np.allclose(Jhat(f), 2.0)