Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reverse-over-forward AD #3681

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9e84696
Reverse-over-forward AD: AssembleBlock
jrmaddison Jul 10, 2024
bbbe841
In-place assignment
jrmaddison Jul 10, 2024
18a5000
Remove restored_output, now handled within pyadjoint
jrmaddison Jul 10, 2024
73c3051
Reverse-over-forward AD: ConstantAssignBlock
jrmaddison Jul 10, 2024
ceab69c
Rename variables
jrmaddison Jul 10, 2024
e42dd5e
Reverse-over-forward AD: FunctionAssignBlock. Minor test edits.
jrmaddison Jul 10, 2024
26865be
Reverse-over-forward AD: GenericSolveBlock
jrmaddison Jul 10, 2024
b9fe732
Reverse-over-forward AD: SupermeshProjectBlock
jrmaddison Jul 10, 2024
7dbd380
Fix
jrmaddison Jul 10, 2024
4c74bd1
flake8
jrmaddison Jul 10, 2024
f10d7bb
Reverse-over-forward AD: FunctionAssignBlock, self.expr==None case
jrmaddison Jul 10, 2024
a264a87
Reverse-over-forward AD: SubfunctionBlock and FunctionMergeBlock
jrmaddison Jul 10, 2024
6774688
Expand annotation in tests
jrmaddison Jul 10, 2024
e0666e5
Reverse-over-forward AD: DirichletBCBlock
jrmaddison Jul 10, 2024
c20303a
Fix names
jrmaddison Jul 10, 2024
8d03402
flake8
jrmaddison Jul 10, 2024
f48062b
Test case where inputs and outputs are different variable versions
jrmaddison Jul 11, 2024
386f1f9
Fix test_project_overwrite
jrmaddison Jul 11, 2024
468de37
Import fix
jrmaddison Jul 11, 2024
f37e642
Add TLM Taylor verification to test_dirichletbc
jrmaddison Jul 12, 2024
f5023c6
Reverse-over-forward: AssembleBlock bugfix
jrmaddison Jul 12, 2024
bf343ab
Defensive copy to handle assemble(Cofunction) case
jrmaddison Jul 12, 2024
d494d0c
ConstantAssignBlock.solve_tlm tidying
jrmaddison Jul 12, 2024
0dc5269
DirichletBCBlock.solve_tlm tidying
jrmaddison Jul 12, 2024
00ddc36
Handle ZeroBaseForm case
jrmaddison Jul 12, 2024
3688343
Test improvements
jrmaddison Jul 12, 2024
5dec19a
Test improvements
jrmaddison Jul 12, 2024
ab80fe6
DROP BEFORE MERGE
jrmaddison Jul 11, 2024
4bb4e3b
Fix context manager
jrmaddison Jul 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch pyadjoint jrmaddison/reverse_over_forward \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
4 changes: 3 additions & 1 deletion firedrake/adjoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from pyadjoint.tape import Tape, set_working_tape, get_working_tape, \
pause_annotation, continue_annotation, \
stop_annotating, annotate_tape # noqa F401
stop_annotating, annotate_tape, \
pause_reverse_over_forward, continue_reverse_over_forward, \
stop_reverse_over_forward # noqa F401
from pyadjoint.reduced_functional import ReducedFunctional # noqa F401
from firedrake.adjoint_utils.checkpointing import \
enable_disk_checkpointing, pause_disk_checkpointing, \
Expand Down
28 changes: 28 additions & 0 deletions firedrake/adjoint_utils/blocks/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,34 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
dform = firedrake.assemble(dform)
return dform

def solve_tlm(self):
x, = self.get_outputs()
form = self.form

tlm_rhs = 0
for block_variable in self.get_dependencies():
dep = block_variable.output
tlm_dep = block_variable.tlm_value
if tlm_dep is not None:
if isinstance(dep, firedrake.MeshGeometry):
dep = firedrake.SpatialCoordinate(dep)
tlm_rhs = tlm_rhs + firedrake.derivative(
form, dep, tlm_dep)
else:
tlm_rhs = tlm_rhs + firedrake.action(
firedrake.derivative(form, dep), tlm_dep)

x.tlm_value = None
if isinstance(tlm_rhs, int) and tlm_rhs == 0:
return
tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs)
if isinstance(tlm_rhs, ufl.ZeroBaseForm) or (isinstance(tlm_rhs, ufl.Form) and tlm_rhs.empty()):
return
tlm_rhs = firedrake.assemble(tlm_rhs)
if tlm_rhs in {dep.output for dep in self.get_dependencies()}:
tlm_rhs = tlm_rhs.copy(deepcopy=True)
x.tlm_value = tlm_rhs

def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
relevant_dependencies):
return self.prepare_evaluate_adj(inputs, adj_inputs,
Expand Down
18 changes: 18 additions & 0 deletions firedrake/adjoint_utils/blocks/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy

from pyadjoint.reduced_functional_numpy import gather
import firedrake
from .block_utils import isconstant


Expand Down Expand Up @@ -70,6 +71,23 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
values = values.values()
return constant_from_values(block_variable.output, values)

def solve_tlm(self):
if self.assigned_list:
# Not reachable?
raise NotImplementedError

x, = self.get_outputs()
dep, = self.get_dependencies()
if dep.tlm_value is None:
x.tlm_value = None
else:
if len(x.output.ufl_shape) == 0:
x.tlm_value = firedrake.Constant(0.0)
else:
x.tlm_value = firedrake.Constant(
numpy.reshape(numpy.zeros_like(x.output.values()), x.output.ufl_shape))
x.tlm_value.assign(dep.tlm_value)

def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
relevant_dependencies):
return self.prepare_evaluate_adj(inputs, hessian_inputs,
Expand Down
8 changes: 8 additions & 0 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
m = bc.reconstruct(g=tlm_input)
return m

def solve_tlm(self):
x, = self.get_outputs()
dep = x.output._function_arg
if isinstance(dep, OverloadedType):
x.tlm_value = x.output.reconstruct(g=dep.block_variable.tlm_value)
else:
x.tlm_value = None

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
block_variable, idx, relevant_dependencies,
prepared=None):
Expand Down
46 changes: 46 additions & 0 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,32 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,

return dudm

def solve_tlm(self):
x, = self.get_outputs()
expr = self.expr

if expr is None:
other, = self.get_dependencies()
if other.tlm_value is None:
x.tlm_value = None
else:
x.tlm_value = firedrake.Function(x.output.function_space()).assign(other.tlm_value)
else:
tlm_rhs = 0
for block_variable in self.get_dependencies():
dep = block_variable.output
tlm_dep = block_variable.tlm_value
if tlm_dep is not None:
tlm_rhs = tlm_rhs + ufl.derivative(expr, dep, tlm_dep)

x.tlm_value = None
if isinstance(tlm_rhs, int) and tlm_rhs == 0:
return
tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs)
if isinstance(tlm_rhs, ufl.constantvalue.Zero):
return
x.tlm_value = firedrake.Function(x.output.function_space()).assign(tlm_rhs)

def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
relevant_dependencies):
return self.prepare_evaluate_adj(inputs, hessian_inputs,
Expand Down Expand Up @@ -211,6 +237,15 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
prepared=None):
return firedrake.Function.sub(tlm_inputs[0], self.idx)

def solve_tlm(self):
x, = self.get_outputs()
dep, = self.get_dependencies()
tlm_dep = dep.tlm_value
if tlm_dep is None:
x.tlm_value = None
else:
x.tlm_value = firedrake.Function(x.output.function_space()).assign(tlm_dep.sub(self.idx))

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
block_variable, idx,
relevant_dependencies, prepared=None):
Expand Down Expand Up @@ -253,6 +288,17 @@ def evaluate_tlm(self):
type(output.output).assign(f.sub(self.idx), tlm_input)
)

def solve_tlm(self):
x, = self.get_outputs()
tlm_dep = self.get_dependencies()[0].tlm_value
if tlm_dep is None:
if x.tlm_value is not None:
x.tlm_value.sub(self.idx).zero()
else:
if x.tlm_value is None:
x.tlm_value = type(x.output)(x.output.function_space())
x.tlm_value.sub(self.idx).assign(tlm_dep)

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
block_variable, idx,
relevant_dependencies, prepared=None):
Expand Down
60 changes: 60 additions & 0 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,56 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
dFdm, dudm, bcs
)

def solve_tlm(self):
x, = self.get_outputs()
if self.linear:
tmp_x = firedrake.Function(x.output.function_space())
replace_map = {tmp_x: x.output}
form = firedrake.action(self.lhs, tmp_x) - self.rhs
else:
replace_map = None
form = self.lhs

tlm_rhs = 0
tlm_bcs = []
for block_variable in self.get_dependencies():
dep = block_variable.output
if dep == x.output and not self.linear:
continue
tlm_dep = block_variable.tlm_value
if isinstance(dep, firedrake.DirichletBC):
if tlm_dep is None:
tlm_bcs.append(dep.reconstruct(g=0))
else:
tlm_bcs.append(tlm_dep)
elif tlm_dep is not None:
if isinstance(dep, firedrake.MeshGeometry):
dep = firedrake.SpatialCoordinate(dep)
tlm_rhs = tlm_rhs - firedrake.derivative(
form, dep, tlm_dep)
else:
tlm_rhs = tlm_rhs - firedrake.action(
firedrake.derivative(form, dep), tlm_dep)

if isinstance(tlm_rhs, int) and tlm_rhs == 0:
tlm_rhs = firedrake.Cofunction(x.output.function_space().dual())
else:
tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs)
if isinstance(tlm_rhs, ufl.ZeroBaseForm) or (isinstance(tlm_rhs, ufl.Form) and tlm_rhs.empty()):
tlm_rhs = firedrake.Cofunction(x.output.function_space().dual())

if self.linear:
J = self.lhs
else:
J = firedrake.derivative(form, x.output, firedrake.TrialFunction(x.output.function_space()))

if replace_map is not None:
J = ufl.replace(J, replace_map)
tlm_rhs = ufl.replace(tlm_rhs, replace_map)
x.tlm_value = firedrake.Function(x.output.function_space())
firedrake.solve(J == tlm_rhs, x.tlm_value, tlm_bcs, *self.forward_args,
**self.forward_kwargs)

def _assemble_and_solve_tlm_eq(self, dFdu, dFdm, dudm, bcs):
return self._assembled_solve(dFdu, dFdm, dudm, bcs)

Expand Down Expand Up @@ -798,6 +848,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
mesh = target_space.mesh()
self.source_space = source.function_space()
self.target_space = target_space
self._kwargs = dict(kwargs)
self.projector = firedrake.Projector(source, target_space, **kwargs)

# Assemble mixed mass matrix
Expand Down Expand Up @@ -878,6 +929,15 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
prepared)
return dJdm

def solve_tlm(self):
x, = self.get_outputs()
dep, = self.get_dependencies()
if dep.tlm_value is None:
x.tlm_value = None
else:
x.tlm_value = firedrake.Function(x.output.function_space())
firedrake.project(dep.tlm_value, x.tlm_value, **self._kwargs)

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
block_variable, idx,
relevant_dependencies, prepared=None):
Expand Down
3 changes: 3 additions & 0 deletions firedrake/adjoint_utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def _ad_assign_numpy(dst, src, offset):
def _ad_to_list(m):
return m.dat.data_ro.reshape(-1).tolist()

def _ad_assign(self, other):
self.assign(other)

def _ad_copy(self):
return self._constant_from_values()

Expand Down
3 changes: 3 additions & 0 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ def _ad_to_list(m):

return m_a.tolist()

def _ad_assign(self, other):
self.assign(other)

def _ad_copy(self):
from firedrake import Function

Expand Down
2 changes: 1 addition & 1 deletion requirements-git.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ git+https://github.com/firedrakeproject/fiat.git#egg=fiat
git+https://github.com/FInAT/FInAT.git#egg=finat
git+https://github.com/firedrakeproject/tsfc.git#egg=tsfc
git+https://github.com/OP2/PyOP2.git#egg=pyop2
git+https://github.com/dolfin-adjoint/pyadjoint.git#egg=pyadjoint
git+https://github.com/jrmaddison/pyadjoint.git#egg=pyadjoint
git+https://github.com/firedrakeproject/petsc.git@firedrake#egg=petsc
git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi
Loading
Loading