Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove apply_default_restrictions() #329

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/test_apply_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
i,
triangle,
)
from ufl.algorithms.apply_restrictions import apply_default_restrictions, apply_restrictions
from ufl.algorithms.apply_restrictions import apply_restrictions
from ufl.algorithms.renumbering import renumber_indices
from ufl.finiteelement import FiniteElement
from ufl.pullback import identity_pullback
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_apply_restrictions():
assert apply_restrictions((grad(f) + grad(g))("-")) == (grad(f)("-") + grad(g)("-"))

# x is the same from both sides but computed from one of them
assert apply_default_restrictions(x) == x("+")
assert apply_restrictions(x) == x("+")

# n on a linear mesh is opposite pointing from the other side
assert apply_restrictions(n("+")) == n("+")
Expand Down
99 changes: 26 additions & 73 deletions ufl/algorithms/apply_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@
class RestrictionPropagator(MultiFunction):
"""Restriction propagator."""

def __init__(self, side=None):
def __init__(self, side=None, apply_default=True):
"""Initialise."""
MultiFunction.__init__(self)
self.current_restriction = side
self.default_restriction = "+"
self.apply_default = apply_default
# Caches for propagating the restriction with map_expr_dag
self.vcaches = {"+": {}, "-": {}}
self.rcaches = {"+": {}, "-": {}}
if self.current_restriction is None:
self._rp = {"+": RestrictionPropagator("+"), "-": RestrictionPropagator("-")}
self._rp = {
"+": RestrictionPropagator(side="+", apply_default=apply_default),
"-": RestrictionPropagator(side="-", apply_default=apply_default),
}

def restricted(self, o):
"""When hitting a restricted quantity, visit child with a separate restriction algorithm."""
Expand Down Expand Up @@ -64,9 +68,12 @@ def _require_restriction(self, o):
def _default_restricted(self, o):
"""Restrict a continuous quantity to default side if no current restriction is set."""
r = self.current_restriction
if r is None:
r = self.default_restriction
return o(r)
if r is not None:
return o(r)
if self.apply_default:
return o(self.default_restriction)
else:
return o

def _opposite(self, o):
"""Restrict a quantity to default side.
Expand Down Expand Up @@ -139,6 +146,18 @@ def reference_value(self, o):
reference_cell_volume = _ignore_restriction
reference_facet_volume = _ignore_restriction

# These are the same from either side but to compute them
# cell (or facet) data from one side must be selected:
spatial_coordinate = _default_restricted
# Depends on cell only to get to the facet:
facet_jacobian = _default_restricted
facet_jacobian_determinant = _default_restricted
facet_jacobian_inverse = _default_restricted
facet_area = _default_restricted
min_facet_edge_length = _default_restricted
max_facet_edge_length = _default_restricted
facet_origin = _default_restricted # FIXME: Is this valid for quads?

def coefficient(self, o):
"""Restrict a coefficient.
Expand Down Expand Up @@ -174,76 +193,10 @@ def facet_normal(self, o):
return self._require_restriction(o)


def apply_restrictions(expression):
def apply_restrictions(expression, apply_default=True):
"""Propagate restriction nodes to wrap differential terminals directly."""
integral_types = [
k for k in integral_type_to_measure_name.keys() if k.startswith("interior_facet")
]
rules = RestrictionPropagator()
return map_integrand_dags(rules, expression, only_integral_type=integral_types)


class DefaultRestrictionApplier(MultiFunction):
"""Default restriction applier."""

def __init__(self, side=None):
"""Initialise."""
MultiFunction.__init__(self)
self.current_restriction = side
self.default_restriction = "+"
if self.current_restriction is None:
self._rp = {"+": DefaultRestrictionApplier("+"), "-": DefaultRestrictionApplier("-")}

def terminal(self, o):
"""Apply to terminal."""
# Most terminals are unchanged
return o

# Default: Operators should reconstruct only if subtrees are not touched
operator = MultiFunction.reuse_if_untouched

def restricted(self, o):
"""Apply to restricted."""
# Don't restrict twice
return o

def derivative(self, o):
"""Apply to derivative."""
# I don't think it's safe to just apply default restriction
# to the argument of any derivative, i.e. grad(cg1_function)
# is not continuous across cells even if cg1_function is.
return o

def _default_restricted(self, o):
"""Restrict a continuous quantity to default side if no current restriction is set."""
r = self.current_restriction
if r is None:
r = self.default_restriction
return o(r)

# These are the same from either side but to compute them
# cell (or facet) data from one side must be selected:
spatial_coordinate = _default_restricted
# Depends on cell only to get to the facet:
facet_jacobian = _default_restricted
facet_jacobian_determinant = _default_restricted
facet_jacobian_inverse = _default_restricted
# facet_tangents = _default_restricted
# facet_midpoint = _default_restricted
facet_area = _default_restricted
# facet_diameter = _default_restricted
min_facet_edge_length = _default_restricted
max_facet_edge_length = _default_restricted
facet_origin = _default_restricted # FIXME: Is this valid for quads?


def apply_default_restrictions(expression):
"""Some terminals can be restricted from either side.
This applies a default restriction to such terminals if unrestricted.
"""
integral_types = [
k for k in integral_type_to_measure_name.keys() if k.startswith("interior_facet")
]
rules = DefaultRestrictionApplier()
rules = RestrictionPropagator(apply_default=apply_default)
return map_integrand_dags(rules, expression, only_integral_type=integral_types)
8 changes: 2 additions & 6 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ufl.algorithms.apply_function_pullbacks import apply_function_pullbacks
from ufl.algorithms.apply_geometry_lowering import apply_geometry_lowering
from ufl.algorithms.apply_integral_scaling import apply_integral_scaling
from ufl.algorithms.apply_restrictions import apply_default_restrictions, apply_restrictions
from ufl.algorithms.apply_restrictions import apply_restrictions
from ufl.algorithms.check_arities import check_form_arity
from ufl.algorithms.comparison_checker import do_comparison_check

Expand Down Expand Up @@ -306,10 +306,6 @@ def compute_form_data(
if do_apply_integral_scaling:
form = apply_integral_scaling(form)

# Apply default restriction to fully continuous terminals
if do_apply_default_restrictions:
form = apply_default_restrictions(form)

# Lower abstractions for geometric quantities into a smaller set
# of quantities, allowing the form compiler to deal with a smaller
# set of types and treating geometric quantities like any other
Expand All @@ -334,7 +330,7 @@ def compute_form_data(

# Propagate restrictions to terminals
if do_apply_restrictions:
form = apply_restrictions(form)
form = apply_restrictions(form, apply_default=do_apply_default_restrictions)

# If in real mode, remove any complex nodes introduced during form processing.
if not complex_mode:
Expand Down
Loading