Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dham/abstract reduced functional #3941

Draft
wants to merge 23 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch pyadjoint dham/abstract_reduced_functional
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
19 changes: 15 additions & 4 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,15 @@ def _init_solver_parameters(self, args, kwargs):
self.assemble_kwargs = {}

def __str__(self):
return "solve({} = {})".format(ufl2unicode(self.lhs),
ufl2unicode(self.rhs))
try:
lhs_string = ufl2unicode(self.lhs)
except AttributeError:
lhs_string = str(self.lhs)
try:
rhs_string = ufl2unicode(self.rhs)
except AttributeError:
rhs_string = str(self.rhs)
return "solve({} = {})".format(lhs_string, rhs_string)

def _create_F_form(self):
# Process the equation forms, replacing values with checkpoints,
Expand Down Expand Up @@ -756,7 +763,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
c = block_variable.output
c_rep = block_variable.saved_output

if isinstance(c, firedrake.Function):
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
elif isinstance(c, firedrake.Constant):
mesh = F_form.ufl_domain()
Expand Down Expand Up @@ -793,7 +800,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
replace_map[self.func] = self.get_outputs()[0].saved_output
dFdm = replace(dFdm, replace_map)

dFdm = dFdm * adj_sol
if isinstance(dFdm, firedrake.Argument):
# Corner case. Should be fixed more permanently upstream in UFL.
dFdm = ufl.Action(dFdm, adj_sol)
else:
dFdm = dFdm * adj_sol
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)

return dFdm
Expand Down
91 changes: 21 additions & 70 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import ufl
from ufl.domain import extract_unique_domain
from pyadjoint.overloaded_type import create_overloaded_object, FloatingType
from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape, no_annotations
from pyadjoint.tape import annotate_tape, stop_annotating, get_working_tape
from firedrake.adjoint_utils.blocks import FunctionAssignBlock, ProjectBlock, SubfunctionBlock, FunctionMergeBlock, SupermeshProjectBlock
import firedrake
from .checkpointing import disk_checkpointing, CheckpointFunction, \
Expand Down Expand Up @@ -220,72 +220,15 @@ def _ad_create_checkpoint(self):
else:
return self.copy(deepcopy=True)

def _ad_convert_riesz(self, value, options=None):
from firedrake import Function, Cofunction

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
solver_options = options.get("solver_options", {})
V = options.get("function_space", self.function_space())
if value == 0.:
# In adjoint-based differentiation, value == 0. arises only when
# the functional is independent on the control variable.
return Function(V)

if not isinstance(value, (Cofunction, Function)):
raise TypeError("Expected a Cofunction or a Function")

if riesz_representation == "l2":
return Function(V, val=value.dat)

elif riesz_representation in ("L2", "H1"):
if not isinstance(value, Cofunction):
raise TypeError("Expected a Cofunction")

ret = Function(V)
a = self._define_riesz_map_form(riesz_representation, V)
firedrake.solve(a == value, ret, **solver_options)
return ret

elif callable(riesz_representation):
return riesz_representation(value)

else:
raise ValueError(
"Unknown Riesz representation %s" % riesz_representation)
def _ad_convert_riesz(self, value, riesz_map=None):
return value.riesz_representation(riesz_map=riesz_map or "L2")

def _define_riesz_map_form(self, riesz_representation, V):
from firedrake import TrialFunction, TestFunction

u = TrialFunction(V)
v = TestFunction(V)
if riesz_representation == "L2":
a = firedrake.inner(u, v)*firedrake.dx

elif riesz_representation == "H1":
a = firedrake.inner(u, v)*firedrake.dx \
+ firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx

else:
raise NotImplementedError(
"Unknown Riesz representation %s" % riesz_representation)
return a

@no_annotations
def _ad_convert_type(self, value, options=None):
# `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz`
options = {} if options is None else options.copy()
options.setdefault("riesz_representation", "L2")
if options["riesz_representation"] is None:
if value == 0.:
# In adjoint-based differentiation, value == 0. arises only when
# the functional is independent on the control variable.
V = options.get("function_space", self.function_space())
return firedrake.Cofunction(V.dual())
else:
return value
def _ad_init_zero(self, dual=False):
from firedrake import Function, Cofunction
if dual:
return Cofunction(self.function_space().dual())
else:
return self._ad_convert_riesz(value, options=options)
return Function(self.function_space())

def _ad_restore_at_checkpoint(self, checkpoint):
if isinstance(checkpoint, CheckpointBase):
Expand All @@ -294,17 +237,16 @@ def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint

def _ad_will_add_as_dependency(self):
"""Method called when the object is added as a Block dependency.

"""
"""Method called when the object is added as a Block dependency."""
with checkpoint_init_data():
super()._ad_will_add_as_dependency()

def _ad_mul(self, other):
from firedrake import Function

r = Function(self.function_space())
# `self` can be a Cofunction in which case only left multiplication with a scalar is allowed.
# `self` can be a Cofunction in which case only left multiplication
# with a scalar is allowed.
r.assign(other * self)
return r

Expand All @@ -316,7 +258,10 @@ def _ad_add(self, other):
return r

def _ad_dot(self, other, options=None):
from firedrake import assemble
from firedrake import assemble, action, Cofunction

if isinstance(other, Cofunction):
return assemble(action(other, self))

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
Expand Down Expand Up @@ -406,3 +351,9 @@ def _ad_to_petsc(self, vec=None):

def __deepcopy__(self, memodict={}):
return self.copy(deepcopy=True)


class CofunctionMixin(FunctionMixin):

def _ad_dot(self, other):
return firedrake.assemble(firedrake.action(self, other))
Loading
Loading