diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4f61ab2d9b..f3c791e05f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -82,6 +82,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch pyadjoint jrmaddison/reverse_over_forward \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index 8373a4285a..c3f4c8a00a 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -19,7 +19,9 @@ from pyadjoint.tape import Tape, set_working_tape, get_working_tape, \ pause_annotation, continue_annotation, \ - stop_annotating, annotate_tape # noqa F401 + stop_annotating, annotate_tape, \ + pause_reverse_over_forward, continue_reverse_over_forward, \ + stop_reverse_over_forward # noqa F401 from pyadjoint.reduced_functional import ReducedFunctional # noqa F401 from firedrake.adjoint_utils.checkpointing import \ enable_disk_checkpointing, pause_disk_checkpointing, \ diff --git a/firedrake/adjoint_utils/blocks/assembly.py b/firedrake/adjoint_utils/blocks/assembly.py index 8ba790ec74..82568b6848 100644 --- a/firedrake/adjoint_utils/blocks/assembly.py +++ b/firedrake/adjoint_utils/blocks/assembly.py @@ -145,6 +145,34 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, dform = firedrake.assemble(dform) return dform + def solve_tlm(self): + x, = self.get_outputs() + form = self.form + + tlm_rhs = 0 + for block_variable in self.get_dependencies(): + dep = block_variable.output + tlm_dep = block_variable.tlm_value + if tlm_dep is not None: + if isinstance(dep, firedrake.MeshGeometry): + dep = firedrake.SpatialCoordinate(dep) + tlm_rhs = tlm_rhs + firedrake.derivative( + form, dep, tlm_dep) + else: + tlm_rhs = tlm_rhs + firedrake.action( + firedrake.derivative(form, dep), tlm_dep) + + x.tlm_value = None + if isinstance(tlm_rhs, int) and tlm_rhs == 0: + return + tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs) + if isinstance(tlm_rhs, ufl.ZeroBaseForm) or (isinstance(tlm_rhs, ufl.Form) and tlm_rhs.empty()): + return + tlm_rhs = firedrake.assemble(tlm_rhs) + if tlm_rhs in {dep.output for dep in self.get_dependencies()}: + tlm_rhs = tlm_rhs.copy(deepcopy=True) + x.tlm_value = tlm_rhs + def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): return self.prepare_evaluate_adj(inputs, adj_inputs, diff --git a/firedrake/adjoint_utils/blocks/constant.py b/firedrake/adjoint_utils/blocks/constant.py index d5e580428a..40178916b7 100644 --- a/firedrake/adjoint_utils/blocks/constant.py +++ b/firedrake/adjoint_utils/blocks/constant.py @@ -2,6 +2,7 @@ import numpy from pyadjoint.reduced_functional_numpy import gather +import firedrake from .block_utils import isconstant @@ -70,6 +71,23 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, values = values.values() return constant_from_values(block_variable.output, values) + def solve_tlm(self): + if self.assigned_list: + # Not reachable? + raise NotImplementedError + + x, = self.get_outputs() + dep, = self.get_dependencies() + if dep.tlm_value is None: + x.tlm_value = None + else: + if len(x.output.ufl_shape) == 0: + x.tlm_value = firedrake.Constant(0.0) + else: + x.tlm_value = firedrake.Constant( + numpy.reshape(numpy.zeros_like(x.output.values()), x.output.ufl_shape)) + x.tlm_value.assign(dep.tlm_value) + def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): return self.prepare_evaluate_adj(inputs, hessian_inputs, diff --git a/firedrake/adjoint_utils/blocks/dirichlet_bc.py b/firedrake/adjoint_utils/blocks/dirichlet_bc.py index 7a4784812c..9ee708c5b6 100644 --- a/firedrake/adjoint_utils/blocks/dirichlet_bc.py +++ b/firedrake/adjoint_utils/blocks/dirichlet_bc.py @@ -121,6 +121,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, m = bc.reconstruct(g=tlm_input) return m + def solve_tlm(self): + x, = self.get_outputs() + dep = x.output._function_arg + if isinstance(dep, OverloadedType): + x.tlm_value = x.output.reconstruct(g=dep.block_variable.tlm_value) + else: + x.tlm_value = None + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index fc5be8486a..30c4941448 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -129,6 +129,32 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, return dudm + def solve_tlm(self): + x, = self.get_outputs() + expr = self.expr + + if expr is None: + other, = self.get_dependencies() + if other.tlm_value is None: + x.tlm_value = None + else: + x.tlm_value = firedrake.Function(x.output.function_space()).assign(other.tlm_value) + else: + tlm_rhs = 0 + for block_variable in self.get_dependencies(): + dep = block_variable.output + tlm_dep = block_variable.tlm_value + if tlm_dep is not None: + tlm_rhs = tlm_rhs + ufl.derivative(expr, dep, tlm_dep) + + x.tlm_value = None + if isinstance(tlm_rhs, int) and tlm_rhs == 0: + return + tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs) + if isinstance(tlm_rhs, ufl.constantvalue.Zero): + return + x.tlm_value = firedrake.Function(x.output.function_space()).assign(tlm_rhs) + def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): return self.prepare_evaluate_adj(inputs, hessian_inputs, @@ -211,6 +237,15 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None): return firedrake.Function.sub(tlm_inputs[0], self.idx) + def solve_tlm(self): + x, = self.get_outputs() + dep, = self.get_dependencies() + tlm_dep = dep.tlm_value + if tlm_dep is None: + x.tlm_value = None + else: + x.tlm_value = firedrake.Function(x.output.function_space()).assign(tlm_dep.sub(self.idx)) + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): @@ -253,6 +288,17 @@ def evaluate_tlm(self): type(output.output).assign(f.sub(self.idx), tlm_input) ) + def solve_tlm(self): + x, = self.get_outputs() + tlm_dep = self.get_dependencies()[0].tlm_value + if tlm_dep is None: + if x.tlm_value is not None: + x.tlm_value.sub(self.idx).zero() + else: + if x.tlm_value is None: + x.tlm_value = type(x.output)(x.output.function_space()) + x.tlm_value.sub(self.idx).assign(tlm_dep) + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 7a57e883c3..72c3056dae 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -317,6 +317,56 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, dFdm, dudm, bcs ) + def solve_tlm(self): + x, = self.get_outputs() + if self.linear: + tmp_x = firedrake.Function(x.output.function_space()) + replace_map = {tmp_x: x.output} + form = firedrake.action(self.lhs, tmp_x) - self.rhs + else: + replace_map = None + form = self.lhs + + tlm_rhs = 0 + tlm_bcs = [] + for block_variable in self.get_dependencies(): + dep = block_variable.output + if dep == x.output and not self.linear: + continue + tlm_dep = block_variable.tlm_value + if isinstance(dep, firedrake.DirichletBC): + if tlm_dep is None: + tlm_bcs.append(dep.reconstruct(g=0)) + else: + tlm_bcs.append(tlm_dep) + elif tlm_dep is not None: + if isinstance(dep, firedrake.MeshGeometry): + dep = firedrake.SpatialCoordinate(dep) + tlm_rhs = tlm_rhs - firedrake.derivative( + form, dep, tlm_dep) + else: + tlm_rhs = tlm_rhs - firedrake.action( + firedrake.derivative(form, dep), tlm_dep) + + if isinstance(tlm_rhs, int) and tlm_rhs == 0: + tlm_rhs = firedrake.Cofunction(x.output.function_space().dual()) + else: + tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs) + if isinstance(tlm_rhs, ufl.ZeroBaseForm) or (isinstance(tlm_rhs, ufl.Form) and tlm_rhs.empty()): + tlm_rhs = firedrake.Cofunction(x.output.function_space().dual()) + + if self.linear: + J = self.lhs + else: + J = firedrake.derivative(form, x.output, firedrake.TrialFunction(x.output.function_space())) + + if replace_map is not None: + J = ufl.replace(J, replace_map) + tlm_rhs = ufl.replace(tlm_rhs, replace_map) + x.tlm_value = firedrake.Function(x.output.function_space()) + firedrake.solve(J == tlm_rhs, x.tlm_value, tlm_bcs, *self.forward_args, + **self.forward_kwargs) + def _assemble_and_solve_tlm_eq(self, dFdu, dFdm, dudm, bcs): return self._assembled_solve(dFdu, dFdm, dudm, bcs) @@ -798,6 +848,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs): mesh = target_space.mesh() self.source_space = source.function_space() self.target_space = target_space + self._kwargs = dict(kwargs) self.projector = firedrake.Projector(source, target_space, **kwargs) # Assemble mixed mass matrix @@ -878,6 +929,15 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared) return dJdm + def solve_tlm(self): + x, = self.get_outputs() + dep, = self.get_dependencies() + if dep.tlm_value is None: + x.tlm_value = None + else: + x.tlm_value = firedrake.Function(x.output.function_space()) + firedrake.project(dep.tlm_value, x.tlm_value, **self._kwargs) + def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None): diff --git a/firedrake/adjoint_utils/constant.py b/firedrake/adjoint_utils/constant.py index fed5978a6d..c9606981d4 100644 --- a/firedrake/adjoint_utils/constant.py +++ b/firedrake/adjoint_utils/constant.py @@ -104,6 +104,9 @@ def _ad_assign_numpy(dst, src, offset): def _ad_to_list(m): return m.dat.data_ro.reshape(-1).tolist() + def _ad_assign(self, other): + self.assign(other) + def _ad_copy(self): return self._constant_from_values() diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index d56438fec7..5492ad04a4 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -352,6 +352,9 @@ def _ad_to_list(m): return m_a.tolist() + def _ad_assign(self, other): + self.assign(other) + def _ad_copy(self): from firedrake import Function diff --git a/requirements-git.txt b/requirements-git.txt index 8bf05aad21..db584f08d8 100644 --- a/requirements-git.txt +++ b/requirements-git.txt @@ -3,6 +3,6 @@ git+https://github.com/firedrakeproject/fiat.git#egg=fiat git+https://github.com/FInAT/FInAT.git#egg=finat git+https://github.com/firedrakeproject/tsfc.git#egg=tsfc git+https://github.com/OP2/PyOP2.git#egg=pyop2 -git+https://github.com/dolfin-adjoint/pyadjoint.git#egg=pyadjoint +git+https://github.com/jrmaddison/pyadjoint.git#egg=pyadjoint git+https://github.com/firedrakeproject/petsc.git@firedrake#egg=petsc git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi diff --git a/tests/regression/test_adjoint_reverse_over_forward.py b/tests/regression/test_adjoint_reverse_over_forward.py new file mode 100644 index 0000000000..00c7d4530e --- /dev/null +++ b/tests/regression/test_adjoint_reverse_over_forward.py @@ -0,0 +1,287 @@ +from contextlib import contextmanager +import numpy as np +import pytest + +from firedrake import * +from firedrake.adjoint import * +from firedrake.__future__ import * + + +@pytest.fixture(autouse=True) +def _(): + get_working_tape().clear_tape() + pause_annotation() + pause_reverse_over_forward() + yield + get_working_tape().clear_tape() + pause_annotation() + pause_reverse_over_forward() + + +@contextmanager +def reverse_over_forward(): + continue_annotation() + continue_reverse_over_forward() + try: + yield + finally: + pause_annotation() + pause_reverse_over_forward() + + +@pytest.mark.skipcomplex +def test_assembly(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) + test = TestFunction(space) + + with reverse_over_forward(): + u = Function(space, name="u").interpolate(X[0] - 0.5) + u_ref = u.copy(deepcopy=True) + zeta = Function(space, name="zeta").interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + J = assemble((u ** 3) * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(6 * inner(u_ref * zeta, test) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +def test_constant_assignment(): + with reverse_over_forward(): + a = Constant(2.5) + a.block_variable.tlm_value = Constant(-2.0) + + b = Constant(0.0).assign(a) + + assert float(b.block_variable.tlm_value) == -2.0 + + # Minimal test that the TLM operation is on the tape + _ = compute_gradient(b.block_variable.tlm_value, Control(a.block_variable.tlm_value)) + adj_value = a.block_variable.tlm_value.block_variable.adj_value + assert float(adj_value) == 1.0 + + +@pytest.mark.skipcomplex +def test_function_assignment(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) + test = TestFunction(space) + + with reverse_over_forward(): + u = Function(space, name="u").interpolate(X[0] - 0.5) + u_ref = u.copy(deepcopy=True) + zeta = Function(space, name="zeta").interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space, name="v").assign(u) + J = assemble((v ** 3) * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(6 * inner(u_ref * zeta, test) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +def test_function_assignment_expr(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) + test = TestFunction(space) + + with reverse_over_forward(): + u = Function(space, name="u").interpolate(X[0] - 0.5) + zeta = Function(space, name="zeta").interpolate(X[0]) + u_ref = u.copy(deepcopy=True) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space, name="v").assign(-3 * u) + J = assemble((v ** 3) * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(-162 * inner(u_ref * zeta, test) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +@pytest.mark.parametrize("idx", [0, 1]) +def test_subfunction(idx): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) * FunctionSpace(mesh, "Lagrange", 1) + test = TestFunction(space) + + with reverse_over_forward(): + u = Function(space, name="u") + u.sub(idx).interpolate(X[0] - 0.5) + u_ref = u.copy(deepcopy=True) + zeta = Function(space, name="zeta") + zeta.sub(idx).interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space, name="v") + v.sub(idx).assign(u.sub(idx)) + J = assemble((u.sub(idx) ** 2) * v.sub(idx) * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.sub(idx).dat.data_ro, + assemble(6 * inner(u_ref[idx] * zeta[idx], test[idx]) * dx).dat.data_ro[idx]) + + +@pytest.mark.skipcomplex +def test_interpolate(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space_a = FunctionSpace(mesh, "Lagrange", 1) + space_b = FunctionSpace(mesh, "Lagrange", 2) + test_a = TestFunction(space_a) + + with reverse_over_forward(): + u = Function(space_a, name="u").interpolate(X[0] - 0.5) + u_ref = u.copy(deepcopy=True) + zeta = Function(space_a, name="zeta").interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space_b, name="v").interpolate(u) + J = assemble(v ** 3 * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(6 * inner(u_ref * zeta, test_a) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +def test_interpolate_expr(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space_a = FunctionSpace(mesh, "Lagrange", 1) + space_b = FunctionSpace(mesh, "Lagrange", 2) + test_a = TestFunction(space_a) + + with reverse_over_forward(): + u = Function(space_a, name="u").interpolate(X[0] - 0.5) + u_ref = u.copy(deepcopy=True) + zeta = Function(space_a, name="zeta").interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space_b, name="v").interpolate(-3 * u) + J = assemble(v ** 3 * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(-162 * inner(u_ref * zeta, test_a) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +def test_project(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space_a = FunctionSpace(mesh, "Lagrange", 1) + space_b = FunctionSpace(mesh, "Discontinuous Lagrange", 0) + test_a = TestFunction(space_a) + + with reverse_over_forward(): + u = Function(space_a, name="u").interpolate(X[0] - 0.5) + u_ref = u.copy(deepcopy=True) + zeta = Function(space_a, name="zeta").interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space_b, name="v").project(u) + J = assemble(v ** 3 * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(6 * inner(Function(space_b).project(u_ref) * Function(space_b).project(zeta), test_a) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +def test_project_overwrite(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) + test = TestFunction(space) + + with reverse_over_forward(): + u = Function(space, name="u").interpolate(X[0] - 0.5) + u_ref = u.copy(deepcopy=True) + zeta = Function(space, name="zeta").interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space).project(-2 * u) + w = Function(space).assign(u) + w.project(-2 * w) + assert np.allclose(w.dat.data_ro, v.dat.data_ro) + assert np.allclose(v.block_variable.tlm_value.dat.data_ro, + -2 * zeta.dat.data_ro) + assert np.allclose(w.block_variable.tlm_value.dat.data_ro, + -2 * zeta.dat.data_ro) + J = assemble(w ** 3 * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(-48 * inner(u_ref * zeta, test) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +def test_supermesh_project(): + mesh_a = UnitSquareMesh(10, 10) + mesh_b = UnitSquareMesh(5, 20) + X_a = SpatialCoordinate(mesh_a) + space_a = FunctionSpace(mesh_a, "Lagrange", 1) + space_b = FunctionSpace(mesh_b, "Discontinuous Lagrange", 0) + test_a = TestFunction(space_a) + + with reverse_over_forward(): + u = Function(space_a, name="u").interpolate(X_a[0] - 0.5) + zeta = Function(space_a, name="zeta").interpolate(X_a[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + + v = Function(space_b, name="v").project(u) + J = assemble(v ** 2 * dx) + + _ = compute_gradient(J.block_variable.tlm_value, Control(u)) + adj_value = u.block_variable.adj_value + assert np.allclose( + adj_value.dat.data_ro, + assemble(2 * inner(Function(space_a).project(Function(space_b).project(zeta)), test_a) * dx).dat.data_ro) + + +@pytest.mark.skipcomplex +def test_dirichletbc(): + mesh = UnitIntervalMesh(10) + X = SpatialCoordinate(mesh) + space = FunctionSpace(mesh, "Lagrange", 1) + + with reverse_over_forward(): + u = Function(space, name="u").interpolate(X[0] - 0.5) + zeta = Function(space, name="zeta").interpolate(X[0]) + u.block_variable.tlm_value = zeta.copy(deepcopy=True) + bc = DirichletBC(space, u, "on_boundary") + + v = project(Constant(0.0), space, bcs=bc) + J = assemble(v ** 3 * dx) + + J_hat = ReducedFunctional(J, Control(u)) + assert taylor_test(J_hat, u, zeta, dJdm=J.block_variable.tlm_value) > 1.9 + J_hat = ReducedFunctional(J.block_variable.tlm_value, Control(u)) + assert taylor_test(J_hat, u, Function(space).interpolate(X[0] * X[0])) > 1.9