diff --git a/setup.cfg b/setup.cfg index bdcf22ec1..77f81636d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -59,6 +59,10 @@ ci = [flake8] ignore = E501, W504, + # Line break before operator, no longer PEP8. + W503, + # Indentation, can trigger for valid code. + E129, # ambiguous variable name E741 builtins = ufl diff --git a/test/test_algorithms.py b/test/test_algorithms.py index b5484def3..c68bb15c7 100755 --- a/test/test_algorithms.py +++ b/test/test_algorithms.py @@ -9,7 +9,7 @@ import pytest from pprint import * -from ufl import (FiniteElement, TestFunction, TrialFunction, triangle, +from ufl import (FiniteElement, TestFunction, TrialFunction, Matrix, triangle, div, grad, Argument, dx, adjoint, Coefficient, FacetNormal, inner, dot, ds) from ufl.algorithms import (extract_arguments, expand_derivatives, diff --git a/test/test_duals.py b/test/test_duals.py new file mode 100644 index 000000000..8fdf357bb --- /dev/null +++ b/test/test_duals.py @@ -0,0 +1,319 @@ +#!/usr/bin/env py.test +# -*- coding: utf-8 -*- + +from ufl import FiniteElement, FunctionSpace, MixedFunctionSpace, \ + Coefficient, Matrix, Cofunction, FormSum, Argument, Coargument,\ + TestFunction, TrialFunction, Adjoint, Action, \ + action, adjoint, derivative, tetrahedron, triangle, interval, dx +from ufl.constantvalue import Zero +from ufl.form import ZeroBaseForm + +__authors__ = "India Marsden" +__date__ = "2020-12-28 -- 2020-12-28" + +import pytest + +from ufl.domain import default_domain +from ufl.duals import is_primal, is_dual +from ufl.algorithms.ad import expand_derivatives + + +def test_mixed_functionspace(self): + # Domains + domain_3d = default_domain(tetrahedron) + domain_2d = default_domain(triangle) + domain_1d = default_domain(interval) + # Finite elements + f_1d = FiniteElement("CG", interval, 1) + f_2d = FiniteElement("CG", triangle, 1) + f_3d = FiniteElement("CG", tetrahedron, 1) + # Function spaces + V_3d = FunctionSpace(domain_3d, f_3d) + V_2d = FunctionSpace(domain_2d, f_2d) + V_1d = FunctionSpace(domain_1d, f_1d) + + # MixedFunctionSpace = V_3d x V_2d x V_1d + V = MixedFunctionSpace(V_3d, V_2d, V_1d) + # Check sub spaces + assert is_primal(V_3d) + assert is_primal(V_2d) + assert is_primal(V_1d) + assert is_primal(V) + + # Get dual of V_3 + V_dual = V_3d.dual() + + # Test dual functions on MixedFunctionSpace = V_dual x V_2d x V_1d + V = MixedFunctionSpace(V_dual, V_2d, V_1d) + V_mixed_dual = MixedFunctionSpace(V_dual, V_2d.dual(), V_1d.dual()) + + assert is_dual(V_dual) + assert not is_dual(V) + assert is_dual(V_mixed_dual) + + +def test_dual_coefficients(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + V_dual = V.dual() + + v = Coefficient(V, count=1) + u = Coefficient(V_dual, count=1) + w = Cofunction(V_dual) + + assert is_primal(v) + assert not is_dual(v) + + assert is_dual(u) + assert not is_primal(u) + + assert is_dual(w) + assert not is_primal(w) + + with pytest.raises(ValueError): + x = Cofunction(V) + + +def test_dual_arguments(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + V_dual = V.dual() + + v = Argument(V, 1) + u = Argument(V_dual, 2) + w = Coargument(V_dual, 3) + + assert is_primal(v) + assert not is_dual(v) + + assert is_dual(u) + assert not is_primal(u) + + assert is_dual(w) + assert not is_primal(w) + + with pytest.raises(ValueError): + x = Coargument(V, 4) + + +def test_addition(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + f_2d_2 = FiniteElement("CG", triangle, 2) + V2 = FunctionSpace(domain_2d, f_2d_2) + V_dual = V.dual() + + u = TrialFunction(V) + v = TestFunction(V) + + # linear 1-form + L = v * dx + a = Cofunction(V_dual) + res = L + a + assert isinstance(res, FormSum) + assert res + + L = u * v * dx + a = Matrix(V, V) + res = L + a + assert isinstance(res, FormSum) + assert res + + # Check BaseForm._add__ simplification + res += ZeroBaseForm((v, u)) + assert res == a + L + # Check Form._add__ simplification + L += ZeroBaseForm((v,)) + assert L == u * v * dx + # Check BaseForm._add__ simplification + res = ZeroBaseForm((v, u)) + res += a + assert res == a + # Check __neg__ + res = L + res -= ZeroBaseForm((v,)) + assert res == L + + with pytest.raises(ValueError): + # Raise error for incompatible arguments + v2 = TestFunction(V2) + res = L + ZeroBaseForm((v2, u)) + + +def test_scalar_mult(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + V_dual = V.dual() + + # linear 1-form + a = Cofunction(V_dual) + res = 2 * a + assert isinstance(res, FormSum) + assert res + + a = Matrix(V, V) + res = 2 * a + assert isinstance(res, FormSum) + assert res + + +def test_adjoint(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + a = Matrix(V, V) + + adj = adjoint(a) + res = 2 * adj + assert isinstance(res, FormSum) + assert res + + res = adjoint(2 * a) + assert isinstance(res, FormSum) + assert isinstance(res.components()[0], Adjoint) + + # Adjoint(Adjoint(.)) = Id + assert adjoint(adj) == a + + +def test_action(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + domain_1d = default_domain(interval) + f_1d = FiniteElement("CG", interval, 1) + U = FunctionSpace(domain_1d, f_1d) + + a = Matrix(V, U) + b = Matrix(V, U.dual()) + u = Coefficient(U) + u_a = Argument(U, 0) + v = Coefficient(V) + ustar = Cofunction(U.dual()) + u_form = u_a * dx + + res = action(a, u) + assert res + assert len(res.arguments()) < len(a.arguments()) + assert isinstance(res, Action) + + repeat = action(res, v) + assert repeat + assert len(repeat.arguments()) < len(res.arguments()) + + res = action(2 * a, u) + assert isinstance(res, FormSum) + assert isinstance(res.components()[0], Action) + + res = action(b, u_form) + assert res + assert len(res.arguments()) < len(b.arguments()) + + with pytest.raises(TypeError): + res = action(a, v) + + with pytest.raises(TypeError): + res = action(a, ustar) + + b2 = Matrix(V, U.dual()) + ustar2 = Cofunction(U.dual()) + # Check Action left-distributivity with FormSum + res = action(b, ustar + ustar2) + assert res == Action(b, ustar) + Action(b, ustar2) + # Check Action right-distributivity with FormSum + res = action(b + b2, ustar) + assert res == Action(b, ustar) + Action(b2, ustar) + + a2 = Matrix(V, U) + u2 = Coefficient(U) + u3 = Coefficient(U) + # Check Action left-distributivity with Sum + # Add 3 Coefficients to check composition of Sum works fine since u + u2 + u3 => Sum(u, Sum(u2, u3)) + res = action(a, u + u2 + u3) + assert res == Action(a, u3) + Action(a, u) + Action(a, u2) + # Check Action right-distributivity with Sum + res = action(a + a2, u) + assert res == Action(a, u) + Action(a2, u) + + +def test_differentiation(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + domain_1d = default_domain(interval) + f_1d = FiniteElement("CG", interval, 1) + U = FunctionSpace(domain_1d, f_1d) + + u = Coefficient(U) + v = Argument(U, 0) + vstar = Argument(U.dual(), 0) + + # -- Cofunction -- # + w = Cofunction(U.dual()) + dwdu = expand_derivatives(derivative(w, u)) + assert isinstance(dwdu, ZeroBaseForm) + assert dwdu.arguments() == (Argument(u.ufl_function_space(), 0),) + # Check compatibility with int/float + assert dwdu == 0 + + dwdw = expand_derivatives(derivative(w, w, vstar)) + assert dwdw == vstar + + dudw = expand_derivatives(derivative(u, w)) + # du/dw is a ufl.Zero and not a ZeroBaseForm + # as we are not differentiating a BaseForm + assert isinstance(dudw, Zero) + assert dudw == 0 + + # -- Coargument -- # + dvstardu = expand_derivatives(derivative(vstar, u)) + assert isinstance(dvstardu, ZeroBaseForm) + assert dvstardu.arguments() == vstar.arguments() + (Argument(u.ufl_function_space(), 1),) + # Check compatibility with int/float + assert dvstardu == 0 + + # -- Matrix -- # + M = Matrix(V, U) + dMdu = expand_derivatives(derivative(M, u)) + assert isinstance(dMdu, ZeroBaseForm) + assert dMdu.arguments() == M.arguments() + (Argument(u.ufl_function_space(), 2),) + # Check compatibility with int/float + assert dMdu == 0 + + # -- Action -- # + Ac = Action(M, u) + dAcdu = expand_derivatives(derivative(Ac, u)) + + # Action(dM/du, u) + Action(M, du/du) = Action(M, uhat) since dM/du = 0. + # Multiply by 1 to get a FormSum (type compatibility). + assert dAcdu == 1 * Action(M, v) + + # -- Adjoint -- # + Ad = Adjoint(M) + dAddu = expand_derivatives(derivative(Ad, u)) + # Push differentiation through Adjoint + assert dAddu == 0 + + # -- Form sum -- # + Fs = M + Ac + dFsdu = expand_derivatives(derivative(Fs, u)) + # Distribute differentiation over FormSum components + assert dFsdu == 1 * Action(M, v) + + +def test_zero_base_form_mult(): + domain_2d = default_domain(triangle) + f_2d = FiniteElement("CG", triangle, 1) + V = FunctionSpace(domain_2d, f_2d) + v = Argument(V, 0) + Z = ZeroBaseForm((v, v)) + + u = Coefficient(V) + + Zu = Z * u + assert Zu == action(Z, u) + assert action(Zu, u) == ZeroBaseForm(()) diff --git a/test/test_equals.py b/test/test_equals.py index abc9d0b55..fd9522ead 100755 --- a/test/test_equals.py +++ b/test/test_equals.py @@ -22,7 +22,7 @@ def test_comparison_of_coefficients(): u2 = Coefficient(U, count=2) u2b = Coefficient(Ub, count=2) - # Itentical objects + # Identical objects assert v1 == v1 assert u2 == u2 @@ -36,6 +36,32 @@ def test_comparison_of_coefficients(): assert not v1 == u1 assert not v2 == u2 +def test_comparison_of_cofunctions(): + V = FiniteElement("CG", triangle, 1) + U = FiniteElement("CG", triangle, 2) + Ub = FiniteElement("CG", triangle, 2) + v1 = Cofunction(V, count=1) + v1b = Cofunction(V, count=1) + v2 = Cofunction(V, count=2) + u1 = Cofunction(U, count=1) + u2 = Cofunction(U, count=2) + u2b = Cofunction(Ub, count=2) + + # Identical objects + assert v1 == v1 + assert u2 == u2 + + # Equal but distinct objects + assert v1 == v1b + assert u2 == u2b + + # Different objects + assert not v1 == v2 + assert not u1 == u2 + assert not v1 == u1 + assert not v2 == u2 + + def test_comparison_of_products(): V = FiniteElement("CG", triangle, 1) diff --git a/test/test_form.py b/test/test_form.py index 10141d51a..71f4a2a20 100755 --- a/test/test_form.py +++ b/test/test_form.py @@ -3,6 +3,7 @@ import pytest from ufl import * +from ufl.form import BaseForm @pytest.fixture @@ -134,3 +135,28 @@ def test_form_call(): a = u*v*dx M = eval("(a @ f) @ g") assert M == g*f*dx + +def test_formsum(mass): + V = FiniteElement("CG", triangle, 1) + v = Cofunction(V) + + assert(v + mass) + assert(mass + v) + assert(isinstance((mass+v), FormSum)) + + assert(len((mass + v + v).components()) == 3) + # Variational forms are summed appropriately + assert(len((mass + v + mass).components()) == 2) + + assert(v - mass) + assert(mass - v) + assert(isinstance((mass+v), FormSum)) + + assert(-v) + assert(isinstance(-v, BaseForm)) + assert((-v).weights()[0] == -1) + + assert(2 * v) + assert(isinstance(2 * v, BaseForm)) + assert((2 * v).weights()[0] == 2) + diff --git a/test/test_mixed_function_space.py b/test/test_mixed_function_space.py index 34d445c84..4cb5abb3a 100644 --- a/test/test_mixed_function_space.py +++ b/test/test_mixed_function_space.py @@ -78,3 +78,17 @@ def test_mixed_functionspace(self): assert ( extract_blocks(f,0) == f_3 ) assert ( extract_blocks(f,1) == f_2 ) assert ( extract_blocks(f,2) == f_1 ) + + # Test dual space method + V_dual = V.dual() + assert( V_dual.num_sub_spaces() == 3 ) + assert( V_dual.ufl_sub_space(0) == V_3d.dual() ) + assert( V_dual.ufl_sub_space(1) == V_2d.dual() ) + assert( V_dual.ufl_sub_space(2) == V_1d.dual() ) + + V_dual = V.dual(*[0,2]) + assert( V_dual.num_sub_spaces() == 3 ) + assert( V_dual.ufl_sub_space(0) == V_3d.dual() ) + assert( V_dual.ufl_sub_space(1) == V_2d ) + assert( V_dual.ufl_sub_space(2) == V_1d.dual() ) + diff --git a/ufl/__init__.py b/ufl/__init__.py index 0517989fc..7bb84df1c 100644 --- a/ufl/__init__.py +++ b/ufl/__init__.py @@ -288,13 +288,22 @@ from ufl.functionspace import FunctionSpace, MixedFunctionSpace # Arguments -from ufl.argument import Argument, TestFunction, TrialFunction, \ +from ufl.argument import Argument, Coargument, TestFunction, TrialFunction, \ Arguments, TestFunctions, TrialFunctions # Coefficients -from ufl.coefficient import Coefficient, Coefficients +from ufl.coefficient import Coefficient, Cofunction, Coefficients from ufl.constant import Constant, VectorConstant, TensorConstant +# Matrices +from ufl.matrix import Matrix + +# Adjoints +from ufl.adjoint import Adjoint + +# Actions +from ufl.action import Action + # Split function from ufl.split_functions import split @@ -335,7 +344,7 @@ from ufl.measure import Measure, register_integral_type, integral_types, custom_integral_types # Form class -from ufl.form import Form, replace_integral_domains +from ufl.form import Form, BaseForm, FormSum, ZeroBaseForm, replace_integral_domains # Integral classes from ufl.integral import Integral @@ -383,9 +392,10 @@ 'BrokenElement', "WithMapping", 'register_element', 'show_elements', 'FunctionSpace', 'MixedFunctionSpace', - 'Argument', 'TestFunction', 'TrialFunction', + 'Argument','Coargument', 'TestFunction', 'TrialFunction', 'Arguments', 'TestFunctions', 'TrialFunctions', - 'Coefficient', 'Coefficients', + 'Coefficient', 'Cofunction', 'Coefficients', + 'Matrix', 'Adjoint', 'Action', 'Constant', 'VectorConstant', 'TensorConstant', 'split', 'PermutationSymbol', 'Identity', 'zero', 'as_ufl', @@ -407,7 +417,7 @@ 'Dx', 'grad', 'div', 'curl', 'rot', 'nabla_grad', 'nabla_div', 'Dn', 'exterior_derivative', 'jump', 'avg', 'cell_avg', 'facet_avg', 'elem_mult', 'elem_div', 'elem_pow', 'elem_op', - 'Form', + 'Form','FormSum', 'ZeroBaseForm', 'Integral', 'Measure', 'register_integral_type', 'integral_types', 'custom_integral_types', 'replace', 'replace_integral_domains', 'derivative', 'action', 'energy_norm', 'rhs', 'lhs', 'extract_blocks', 'system', 'functional', 'adjoint', 'sensitivity_rhs', diff --git a/ufl/action.py b/ufl/action.py new file mode 100644 index 000000000..7b32b7bf9 --- /dev/null +++ b/ufl/action.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +"""This module defines the Action class.""" + +# Copyright (C) 2021 India Marsden +# +# This file is part of UFL (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later +# +# Modified by Nacime Bouziani, 2021-2022. + +from ufl.form import BaseForm, FormSum, Form, ZeroBaseForm +from ufl.core.ufl_type import ufl_type +from ufl.algebra import Sum +from ufl.argument import Argument +from ufl.coefficient import BaseCoefficient, Coefficient, Cofunction +from ufl.differentiation import CoefficientDerivative +from ufl.matrix import Matrix + +# --- The Action class represents the action of a numerical object that needs +# to be computed at assembly time --- + + +@ufl_type() +class Action(BaseForm): + """UFL base form type: respresents the action of an object on another. + For example: + res = Ax + A would be the first argument, left and x would be the second argument, + right. + + Action objects will result when the action of an assembled object + (e.g. a Matrix) is taken. This delays the evaluation of the action until + assembly occurs. + """ + + __slots__ = ( + "_left", + "_right", + "ufl_operands", + "_repr", + "_arguments", + "_hash") + + def __getnewargs__(self): + return (self._left, self._right) + + def __new__(cls, *args, **kw): + left, right = args + + # Check trivial case + if left == 0 or right == 0: + # Check compatibility of function spaces + _check_function_spaces(left, right) + # Still need to work out the ZeroBaseForm arguments. + new_arguments = _get_action_form_arguments(left, right) + return ZeroBaseForm(new_arguments) + + if isinstance(left, (FormSum, Sum)): + # Action distributes over sums on the LHS + return FormSum(*[(Action(component, right), 1) + for component in left.ufl_operands]) + if isinstance(right, (FormSum, Sum)): + # Action also distributes over sums on the RHS + return FormSum(*[(Action(left, component), 1) + for component in right.ufl_operands]) + + return super(Action, cls).__new__(cls) + + def __init__(self, left, right): + BaseForm.__init__(self) + + self._left = left + self._right = right + self.ufl_operands = (self._left, self._right) + + # Check compatibility of function spaces + _check_function_spaces(left, right) + + self._repr = "Action(%s, %s)" % (repr(self._left), repr(self._right)) + self._hash = None + + def ufl_function_spaces(self): + "Get the tuple of function spaces of the underlying form" + if isinstance(self._right, Form): + return self._left.ufl_function_spaces()[:-1] \ + + self._right.ufl_function_spaces()[1:] + elif isinstance(self._right, Coefficient): + return self._left.ufl_function_spaces()[:-1] + + def left(self): + return self._left + + def right(self): + return self._right + + def _analyze_form_arguments(self): + """Compute the Arguments of this Action. + + The highest number Argument of the left operand and the lowest number + Argument of the right operand are consumed by the action. + """ + self._arguments = _get_action_form_arguments(self._left, self._right) + + def equals(self, other): + if type(other) is not Action: + return False + if self is other: + return True + return (self._left == other._left and self._right == other._right) + + def __str__(self): + return "Action(%s, %s)" % (str(self._left), str(self._right)) + + def __repr__(self): + return self._repr + + def __hash__(self): + "Hash code for use in dicts " + if self._hash is None: + self._hash = hash(("Action", + hash(self._right), + hash(self._left))) + return self._hash + + +def _check_function_spaces(left, right): + """Check if the function spaces of left and right match.""" + + if isinstance(right, CoefficientDerivative): + # Action differentiation pushes differentiation through + # right as a consequence of Leibniz formula. + right, *_ = right.ufl_operands + + if isinstance(right, (Form, Action, Matrix, ZeroBaseForm)): + if (left.arguments()[-1].ufl_function_space().dual() + != right.arguments()[0].ufl_function_space()): + + raise TypeError("Incompatible function spaces in Action") + elif isinstance(right, (Coefficient, Cofunction, Argument)): + if (left.arguments()[-1].ufl_function_space() + != right.ufl_function_space()): + + raise TypeError("Incompatible function spaces in Action") + else: + raise TypeError("Incompatible argument in Action: %s" % type(right)) + + +def _get_action_form_arguments(left, right): + """Perform argument contraction to work out the arguments of Action""" + + if isinstance(right, CoefficientDerivative): + # Action differentiation pushes differentiation through + # right as a consequence of Leibniz formula. + right, *_ = right.ufl_operands + + if isinstance(right, BaseForm): + return left.arguments()[:-1] + right.arguments()[1:] + elif isinstance(right, BaseCoefficient): + return left.arguments()[:-1] + elif isinstance(right, Argument): + return left.arguments()[:-1] + (right,) + else: + raise TypeError diff --git a/ufl/adjoint.py b/ufl/adjoint.py new file mode 100644 index 000000000..8129596de --- /dev/null +++ b/ufl/adjoint.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +"""This module defines the Adjoint class.""" + +# Copyright (C) 2021 India Marsden +# +# This file is part of UFL (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later +# +# Modified by Nacime Bouziani, 2021-2022. + +from ufl.form import BaseForm, FormSum, ZeroBaseForm +from ufl.core.ufl_type import ufl_type +# --- The Adjoint class represents the adjoint of a numerical object that +# needs to be computed at assembly time --- + + +@ufl_type() +class Adjoint(BaseForm): + """UFL base form type: represents the adjoint of an object. + + Adjoint objects will result when the adjoint of an assembled object + (e.g. a Matrix) is taken. This delays the evaluation of the adjoint until + assembly occurs. + """ + + __slots__ = ( + "_form", + "_repr", + "_arguments", + "ufl_operands", + "_hash") + + def __getnewargs__(self): + return (self._form) + + def __new__(cls, *args, **kw): + form = args[0] + # Check trivial case: This is not a ufl.Zero but a ZeroBaseForm! + if form == 0: + # Swap the arguments + return ZeroBaseForm(form.arguments()[::-1]) + + if isinstance(form, Adjoint): + return form._form + elif isinstance(form, FormSum): + # Adjoint distributes over sums + return FormSum(*[(Adjoint(component), 1) + for component in form.components()]) + + return super(Adjoint, cls).__new__(cls) + + def __init__(self, form): + BaseForm.__init__(self) + + if len(form.arguments()) != 2: + raise ValueError("Can only take Adjoint of a 2-form.") + + self._form = form + self.ufl_operands = (self._form,) + self._hash = None + self._repr = "Adjoint(%s)" % repr(self._form) + + def ufl_function_spaces(self): + "Get the tuple of function spaces of the underlying form" + return self._form.ufl_function_spaces() + + def form(self): + return self._form + + def _analyze_form_arguments(self): + """The arguments of adjoint are the reverse of the form arguments.""" + self._arguments = self._form.arguments()[::-1] + + def equals(self, other): + if type(other) is not Adjoint: + return False + if self is other: + return True + return (self._form == other._form) + + def __str__(self): + return "Adjoint(%s)" % str(self._form) + + def __repr__(self): + return self._repr + + def __hash__(self): + """Hash code for use in dicts.""" + if self._hash is None: + self._hash = hash(("Adjoint", hash(self._form))) + return self._hash diff --git a/ufl/algorithms/ad.py b/ufl/algorithms/ad.py index faaa8e615..4755070ae 100644 --- a/ufl/algorithms/ad.py +++ b/ufl/algorithms/ad.py @@ -9,11 +9,14 @@ # # Modified by Anders Logg, 2009. +import warnings + +from ufl.adjoint import Adjoint from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering from ufl.algorithms.apply_derivatives import apply_derivatives -def expand_derivatives(form): +def expand_derivatives(form, **kwargs): """Expand all derivatives of expr. In the returned expression g which is mathematically @@ -21,6 +24,18 @@ def expand_derivatives(form): or CoefficientDerivative objects left, and Grad objects have been propagated to Terminal nodes. """ + # For a deprecation period (I see that dolfin-adjoint passes some + # args here) + if kwargs: + warnings("Deprecation: expand_derivatives no longer takes any keyword arguments") + + if isinstance(form, Adjoint): + dform = expand_derivatives(form._form) + if dform == 0: + return dform + # Adjoint is taken on a 3-form which can't happen + raise NotImplementedError('Adjoint derivative is not supported.') + # Lower abstractions for tensor-algebra types into index notation form = apply_algebra_lowering(form) diff --git a/ufl/algorithms/analysis.py b/ufl/algorithms/analysis.py index 4d9265769..ab74c1582 100644 --- a/ufl/algorithms/analysis.py +++ b/ufl/algorithms/analysis.py @@ -15,10 +15,11 @@ from ufl.log import error from ufl.utils.sorting import sorted_by_count, topological_sorting -from ufl.core.terminal import Terminal, FormArgument -from ufl.argument import Argument -from ufl.coefficient import Coefficient +from ufl.core.terminal import Terminal +from ufl.argument import BaseArgument +from ufl.coefficient import BaseCoefficient from ufl.constant import Constant +from ufl.form import BaseForm, Form from ufl.algorithms.traversal import iter_expressions from ufl.corealg.traversal import unique_pre_traversal, traverse_unique_terminals @@ -45,29 +46,40 @@ def unique_tuple(objects): def __unused__extract_classes(a): """Build a set of all unique Expr subclasses used in a. - The argument a can be a Form, Integral or Expr.""" + The argument a can be a BaseForm, Integral or Expr.""" return set(o._ufl_class_ for e in iter_expressions(a) for o in unique_pre_traversal(e)) -def extract_type(a, ufl_type): - """Build a set of all objects of class ufl_type found in a. - The argument a can be a Form, Integral or Expr.""" - if issubclass(ufl_type, Terminal): +def extract_type(a, ufl_types): + """Build a set of all objects found in a whose class is in ufl_types. + The argument a can be a BaseForm, Integral or Expr.""" + + if not isinstance(ufl_types, (list, tuple)): + ufl_types = (ufl_types,) + + # BaseForms that aren't forms only have arguments + if isinstance(a, BaseForm) and not isinstance(a, Form): + if any(issubclass(t, BaseArgument) for t in ufl_types): + return set(a.arguments()) + else: + return set() + + if all(issubclass(t, Terminal) for t in ufl_types): # Optimization return set(o for e in iter_expressions(a) for o in traverse_unique_terminals(e) - if isinstance(o, ufl_type)) + if any(isinstance(o, t) for t in ufl_types)) else: return set(o for e in iter_expressions(a) for o in unique_pre_traversal(e) - if isinstance(o, ufl_type)) + if any(isinstance(o, t) for t in ufl_types)) def has_type(a, ufl_type): """Return if an object of class ufl_type can be found in a. - The argument a can be a Form, Integral or Expr.""" + The argument a can be a BaseForm, Integral or Expr.""" if issubclass(ufl_type, Terminal): # Optimization traversal = traverse_unique_terminals @@ -78,7 +90,7 @@ def has_type(a, ufl_type): def has_exact_type(a, ufl_type): """Return if an object of class ufl_type can be found in a. - The argument a can be a Form, Integral or Expr.""" + The argument a can be a BaseForm, Integral or Expr.""" tc = ufl_type._ufl_typecode_ if issubclass(ufl_type, Terminal): # Optimization @@ -90,14 +102,14 @@ def has_exact_type(a, ufl_type): def extract_arguments(a): """Build a sorted list of all arguments in a, - which can be a Form, Integral or Expr.""" - return _sorted_by_number_and_part(extract_type(a, Argument)) + which can be a BaseForm, Integral or Expr.""" + return _sorted_by_number_and_part(extract_type(a, BaseArgument)) def extract_coefficients(a): """Build a sorted list of all coefficients in a, - which can be a Form, Integral or Expr.""" - return sorted_by_count(extract_type(a, Coefficient)) + which can be a BaseForm, Integral or Expr.""" + return sorted_by_count(extract_type(a, BaseCoefficient)) def extract_constants(a): @@ -107,15 +119,15 @@ def extract_constants(a): def extract_arguments_and_coefficients(a): """Build two sorted lists of all arguments and coefficients - in a, which can be a Form, Integral or Expr.""" + in a, which can be BaseForm, Integral or Expr.""" # This function is faster than extract_arguments + extract_coefficients # for large forms, and has more validation built in. - # Extract lists of all form argument instances - terminals = extract_type(a, FormArgument) - arguments = [f for f in terminals if isinstance(f, Argument)] - coefficients = [f for f in terminals if isinstance(f, Coefficient)] + # Extract lists of all BaseArgument and BaseCoefficient instances + base_coeff_and_args = extract_type(a, (BaseArgument, BaseCoefficient)) + arguments = [f for f in base_coeff_and_args if isinstance(f, BaseArgument)] + coefficients = [f for f in base_coeff_and_args if isinstance(f, BaseCoefficient)] # Build number,part: instance mappings, should be one to one bfnp = dict((f, (f.number(), f.part())) for f in arguments) diff --git a/ufl/algorithms/apply_algebra_lowering.py b/ufl/algorithms/apply_algebra_lowering.py index 24fc2fbb6..66846fb4e 100644 --- a/ufl/algorithms/apply_algebra_lowering.py +++ b/ufl/algorithms/apply_algebra_lowering.py @@ -29,7 +29,7 @@ class LowerCompoundAlgebra(MultiFunction): def __init__(self): MultiFunction.__init__(self) - expr = MultiFunction.reuse_if_untouched + ufl_type = MultiFunction.reuse_if_untouched # ------------ Compound tensor operators diff --git a/ufl/algorithms/apply_derivatives.py b/ufl/algorithms/apply_derivatives.py index b18d77f69..e6cdaade3 100644 --- a/ufl/algorithms/apply_derivatives.py +++ b/ufl/algorithms/apply_derivatives.py @@ -33,6 +33,7 @@ from ufl.tensors import (as_scalar, as_scalars, as_tensor, unit_indexed_tensor, unwrap_list_tensor) +from ufl.form import ZeroBaseForm # TODO: Add more rulesets? # - DivRuleset # - CurlRuleset @@ -1037,6 +1038,31 @@ def coordinate_derivative(self, o): o = o.ufl_operands return CoordinateDerivative(map_expr_dag(self, o[0]), o[1], o[2], o[3]) + # -- Handlers for BaseForm objects -- # + + def cofunction(self, o): + # Same rule than for Coefficient except that we use a Coargument. + # The coargument is already attached to the class (self._v) + # which `self.coefficient` relies on. + dc = self.coefficient(o) + if dc == 0: + # Convert ufl.Zero into ZeroBaseForm + return ZeroBaseForm(self._v) + return dc + + def coargument(self, o): + # Same rule than for Argument (da/dw == 0). + dc = self.argument(o) + if dc == 0: + # Convert ufl.Zero into ZeroBaseForm + return ZeroBaseForm(o.arguments() + self._v) + return dc + + def matrix(self, M): + # Matrix rule: D_w[v](M) = v if M == w else 0 + # We can't differentiate wrt a matrix so always return zero in the appropriate space + return ZeroBaseForm(M.arguments() + self._v) + class DerivativeRuleDispatcher(MultiFunction): def __init__(self): @@ -1051,7 +1077,7 @@ def terminal(self, o): def derivative(self, o): error("Missing derivative handler for {0}.".format(type(o).__name__)) - expr = MultiFunction.reuse_if_untouched + ufl_type = MultiFunction.reuse_if_untouched def grad(self, o, f): rules = GradRuleset(o.ufl_shape[-1]) diff --git a/ufl/algorithms/formtransformations.py b/ufl/algorithms/formtransformations.py index d11961ca3..765b5163b 100644 --- a/ufl/algorithms/formtransformations.py +++ b/ufl/algorithms/formtransformations.py @@ -428,6 +428,8 @@ def compute_energy_norm(form, coefficient): Arguments, and one additional Coefficient at the end if no coefficient has been provided. """ + from ufl.formoperators import action # Delayed import to avoid circularity + arguments = form.arguments() parts = [arg.part() for arg in arguments] @@ -447,7 +449,7 @@ def compute_energy_norm(form, coefficient): if coefficient.ufl_function_space() != U: error("Trying to compute action of form on a " "coefficient in an incompatible element space.") - return replace(form, {u: coefficient, v: coefficient}) + return action(action(form, coefficient), coefficient) def compute_form_adjoint(form, reordered_arguments=None): diff --git a/ufl/algorithms/map_integrands.py b/ufl/algorithms/map_integrands.py index 0846218f4..728b35f47 100644 --- a/ufl/algorithms/map_integrands.py +++ b/ufl/algorithms/map_integrands.py @@ -15,7 +15,9 @@ from ufl.core.expr import Expr from ufl.corealg.map_dag import map_expr_dag from ufl.integral import Integral -from ufl.form import Form +from ufl.form import Form, BaseForm, FormSum, ZeroBaseForm +from ufl.action import Action +from ufl.adjoint import Adjoint from ufl.constantvalue import Zero @@ -35,7 +37,25 @@ def map_integrands(function, form, only_integral_type=None): return itg.reconstruct(function(itg.integrand())) else: return itg - elif isinstance(form, Expr): + elif isinstance(form, FormSum): + mapped_components = [map_integrands(function, component, only_integral_type) + for component in form.components()] + nonzero_components = [(component, 1) for component in mapped_components + # Catch ufl.Zero and ZeroBaseForm + if component != 0] + return FormSum(*nonzero_components) + elif isinstance(form, Adjoint): + # Zeros are caught inside `Adjoint.__new__` + return Adjoint(map_integrands(function, form._form, only_integral_type)) + elif isinstance(form, Action): + left = map_integrands(function, form._left, only_integral_type) + right = map_integrands(function, form._right, only_integral_type) + # Zeros are caught inside `Action.__new__` + return Action(left, right) + elif isinstance(form, ZeroBaseForm): + arguments = tuple(map_integrands(function, arg, only_integral_type) for arg in form._arguments) + return ZeroBaseForm(arguments) + elif isinstance(form, (Expr, BaseForm)): integrand = form return function(integrand) else: diff --git a/ufl/algorithms/replace.py b/ufl/algorithms/replace.py index 39c4bd9e3..e3a8fb815 100644 --- a/ufl/algorithms/replace.py +++ b/ufl/algorithms/replace.py @@ -24,7 +24,7 @@ def __init__(self, mapping): if not all(k.ufl_shape == v.ufl_shape for k, v in mapping.items()): error("Replacement expressions must have the same shape as what they replace.") - def expr(self, o, *args): + def ufl_type(self, o, *args): try: return self.mapping[o] except KeyError: diff --git a/ufl/algorithms/transformer.py b/ufl/algorithms/transformer.py index 7d4f9a87a..5587aff35 100644 --- a/ufl/algorithms/transformer.py +++ b/ufl/algorithms/transformer.py @@ -16,6 +16,7 @@ from ufl.algorithms.map_integrands import map_integrands from ufl.classes import Variable, all_ufl_classes +from ufl.core.ufl_type import UFLType from ufl.log import error @@ -50,7 +51,13 @@ def __init__(self, variable_cache=None): for c in classobject.mro(): # Register classobject with handler for the first # encountered superclass - handler_name = c._ufl_handler_name_ + try: + handler_name = c._ufl_handler_name_ + except AttributeError as attribute_error: + if type(classobject) is not UFLType: + raise attribute_error + # Default handler name for UFL types + handler_name = UFLType._ufl_handler_name_ function = getattr(self, handler_name, None) if function: cache_data[ @@ -135,8 +142,8 @@ def always_reconstruct(self, o, *operands): "Always reconstruct expr." return o._ufl_expr_reconstruct_(*operands) - # Set default behaviour for any Expr - expr = undefined + # Set default behaviour for any UFLType + ufl_type = undefined # Set default behaviour for any Terminal terminal = reuse diff --git a/ufl/algorithms/traversal.py b/ufl/algorithms/traversal.py index 4623257e7..809ca1ca3 100644 --- a/ufl/algorithms/traversal.py +++ b/ufl/algorithms/traversal.py @@ -12,7 +12,9 @@ from ufl.log import error from ufl.core.expr import Expr from ufl.integral import Integral -from ufl.form import Form +from ufl.action import Action +from ufl.adjoint import Adjoint +from ufl.form import Form, FormSum, BaseForm # --- Traversal utilities --- @@ -24,11 +26,16 @@ def iter_expressions(a): - a is an Expr: (a,) - a is an Integral: the integrand expression of a - a is a Form: all integrand expressions of all integrals + - a is a FormSum: the components of a + - a is an Action: the left and right component of a + - a is an Adjoint: the underlying form of a """ if isinstance(a, Form): return (itg.integrand() for itg in a.integrals()) elif isinstance(a, Integral): return (a.integrand(),) - elif isinstance(a, Expr): + elif isinstance(a, (FormSum, Adjoint, Action)): + return tuple(e for op in a.ufl_operands for e in iter_expressions(op)) + elif isinstance(a, (Expr, BaseForm)): return (a,) error("Not an UFL type: %s" % str(type(a))) diff --git a/ufl/argument.py b/ufl/argument.py index 27f0a4de4..4c53c00d8 100644 --- a/ufl/argument.py +++ b/ufl/argument.py @@ -19,7 +19,9 @@ from ufl.split_functions import split from ufl.finiteelement import FiniteElementBase from ufl.domain import default_domain +from ufl.form import BaseForm from ufl.functionspace import AbstractFunctionSpace, FunctionSpace, MixedFunctionSpace +from ufl.duals import is_primal, is_dual # Export list for ufl.classes (TODO: not actually classes: drop? these are in ufl.*) __all_classes__ = ["TestFunction", "TrialFunction", "TestFunctions", "TrialFunctions"] @@ -27,19 +29,15 @@ # --- Class representing an argument (basis function) in a form --- -@ufl_type() -class Argument(FormArgument): +class BaseArgument(object): """UFL value: Representation of an argument to a form.""" - __slots__ = ( - "_ufl_function_space", - "_ufl_shape", - "_number", - "_part", - "_repr", - ) + __slots__ = () + _ufl_is_abstract_ = True + + def __getnewargs__(self): + return (self._ufl_function_space, self._number, self._part) def __init__(self, function_space, number, part=None): - FormArgument.__init__(self) if isinstance(function_space, FiniteElementBase): # For legacy support for UFL files using cells, we map the cell to @@ -60,7 +58,7 @@ def __init__(self, function_space, number, part=None): self._number = number self._part = part - self._repr = "Argument(%s, %s, %s)" % ( + self._repr = "BaseArgument(%s, %s, %s)" % ( repr(self._ufl_function_space), repr(self._number), repr(self._part)) @property @@ -138,6 +136,95 @@ def __eq__(self, other): self._part == other._part and self._ufl_function_space == other._ufl_function_space) + +@ufl_type() +class Argument(FormArgument, BaseArgument): + """UFL value: Representation of an argument to a form.""" + __slots__ = ( + "_ufl_function_space", + "_ufl_shape", + "_number", + "_part", + "_repr", + ) + + _primal = True + _dual = False + + __getnewargs__ = BaseArgument.__getnewargs__ + __str__ = BaseArgument.__str__ + _ufl_signature_data_ = BaseArgument._ufl_signature_data_ + + def __new__(cls, *args, **kw): + if args[0] and is_dual(args[0]): + return Coargument(*args, **kw) + return super().__new__(cls) + + def __init__(self, function_space, number, part=None): + FormArgument.__init__(self) + BaseArgument.__init__(self, function_space, number, part) + + self._repr = "Argument(%s, %s, %s)" % ( + repr(self._ufl_function_space), repr(self._number), repr(self._part)) + + def ufl_domains(self): + return BaseArgument.ufl_domains(self) + + def __repr__(self): + return self._repr + + +@ufl_type() +class Coargument(BaseForm, BaseArgument): + """UFL value: Representation of an argument to a form in a dual space.""" + __slots__ = ( + "_ufl_function_space", + "_ufl_shape", + "_arguments", + "ufl_operands", + "_number", + "_part", + "_repr", + "_hash" + ) + + _primal = False + _dual = True + + def __new__(cls, *args, **kw): + if args[0] and is_primal(args[0]): + raise ValueError('ufl.Coargument takes in a dual space! If you want to define an argument in the primal space you should use ufl.Argument.') + return super().__new__(cls) + + def __init__(self, function_space, number, part=None): + BaseArgument.__init__(self, function_space, number, part) + BaseForm.__init__(self) + + self.ufl_operands = () + self._hash = None + self._repr = "Coargument(%s, %s, %s)" % ( + repr(self._ufl_function_space), repr(self._number), repr(self._part)) + + def _analyze_form_arguments(self): + "Analyze which Argument and Coefficient objects can be found in the form." + # Define canonical numbering of arguments and coefficients + self._arguments = (Argument(self._ufl_function_space, 0),) + + def equals(self, other): + if type(other) is not Coargument: + return False + if self is other: + return True + return (self._ufl_function_space == other._ufl_function_space and + self._number == other._number and self._part == other._part) + + def __hash__(self): + """Hash code for use in dicts.""" + return hash(("Coargument", + hash(self._ufl_function_space), + self._number, + self._part)) + # --- Helper functions for pretty syntax --- diff --git a/ufl/coefficient.py b/ufl/coefficient.py index dd6936cab..9caed09bd 100644 --- a/ufl/coefficient.py +++ b/ufl/coefficient.py @@ -18,24 +18,29 @@ from ufl.finiteelement import FiniteElementBase from ufl.domain import default_domain from ufl.functionspace import AbstractFunctionSpace, FunctionSpace, MixedFunctionSpace +from ufl.form import BaseForm from ufl.split_functions import split from ufl.utils.counted import counted_init +from ufl.duals import is_primal, is_dual # --- The Coefficient class represents a coefficient in a form --- -@ufl_type() -class Coefficient(FormArgument): - """UFL form argument type: Representation of a form coefficient.""" +class BaseCoefficient(object): + """UFL form argument type: Parent Representation of a form coefficient.""" # Slots are disabled here because they cause trouble in PyDOLFIN # multiple inheritance pattern: # __slots__ = ("_count", "_ufl_function_space", "_repr", "_ufl_shape") _ufl_noslots_ = True + __slots__ = () _globalcount = 0 + _ufl_is_abstract_ = True + + def __getnewargs__(self): + return (self._ufl_function_space, self._count) def __init__(self, function_space, count=None): - FormArgument.__init__(self) counted_init(self, count, Coefficient) if isinstance(function_space, FiniteElementBase): @@ -50,7 +55,7 @@ def __init__(self, function_space, count=None): self._ufl_function_space = function_space self._ufl_shape = function_space.ufl_element().value_shape() - self._repr = "Coefficient(%s, %s)" % ( + self._repr = "BaseCoefficient(%s, %s)" % ( repr(self._ufl_function_space), repr(self._count)) def count(self): @@ -93,6 +98,94 @@ def __str__(self): def __repr__(self): return self._repr + def __eq__(self, other): + if not isinstance(other, BaseCoefficient): + return False + if self is other: + return True + return (self._count == other._count and + self._ufl_function_space == other._ufl_function_space) + + +@ufl_type() +class Cofunction(BaseCoefficient, BaseForm): + """UFL form argument type: Representation of a form coefficient from a dual space.""" + + __slots__ = ( + "_count", + "_arguments", + "_ufl_function_space", + "ufl_operands", + "_repr", + "_ufl_shape", + "_hash" + ) + # _globalcount = 0 + _primal = False + _dual = True + + def __new__(cls, *args, **kw): + if args[0] and is_primal(args[0]): + raise ValueError('ufl.Cofunction takes in a dual space. If you want to define a coefficient in the primal space you should use ufl.Coefficient.') + return super().__new__(cls) + + def __init__(self, function_space, count=None): + BaseCoefficient.__init__(self, function_space, count) + BaseForm.__init__(self) + + self.ufl_operands = () + self._hash = None + self._repr = "Cofunction(%s, %s)" % ( + repr(self._ufl_function_space), repr(self._count)) + + def equals(self, other): + if type(other) is not Cofunction: + return False + if self is other: + return True + return (self._count == other._count and + self._ufl_function_space == other._ufl_function_space) + + def __hash__(self): + """Hash code for use in dicts.""" + return hash(("Cofunction", + hash(self._ufl_function_space), + self._count)) + + def _analyze_form_arguments(self): + "Analyze which Argument and Coefficient objects can be found in the form." + # Define canonical numbering of arguments and coefficients + self._arguments = () + + +@ufl_type() +class Coefficient(FormArgument, BaseCoefficient): + """UFL form argument type: Representation of a form coefficient.""" + + _ufl_noslots_ = True + _globalcount = 0 + _primal = True + _dual = False + + __getnewargs__ = BaseCoefficient.__getnewargs__ + __str__ = BaseCoefficient.__str__ + _ufl_signature_data_ = BaseCoefficient._ufl_signature_data_ + + def __new__(cls, *args, **kw): + if args[0] and is_dual(args[0]): + return Cofunction(*args, **kw) + return super().__new__(cls) + + def __init__(self, function_space, count=None): + FormArgument.__init__(self) + BaseCoefficient.__init__(self, function_space, count) + + self._repr = "Coefficient(%s, %s)" % ( + repr(self._ufl_function_space), repr(self._count)) + + def ufl_domains(self): + return BaseCoefficient.ufl_domains(self) + def __eq__(self, other): if not isinstance(other, Coefficient): return False @@ -101,6 +194,9 @@ def __eq__(self, other): return (self._count == other._count and self._ufl_function_space == other._ufl_function_space) + def __repr__(self): + return self._repr + # --- Helper functions for subfunctions on mixed elements --- @@ -108,7 +204,7 @@ def Coefficients(function_space): """UFL value: Create a Coefficient in a mixed space, and return a tuple with the function components corresponding to the subelements.""" if isinstance(function_space, MixedFunctionSpace): - return [Coefficient(function_space.ufl_sub_space(i)) - for i in range(function_space.num_sub_spaces())] + return [Coefficient(fs) if is_primal(fs) else Cofunction(fs) + for fs in function_space.num_sub_spaces()] else: return split(Coefficient(function_space)) diff --git a/ufl/constantvalue.py b/ufl/constantvalue.py index 0f0a311a7..b8cf5add1 100644 --- a/ufl/constantvalue.py +++ b/ufl/constantvalue.py @@ -12,6 +12,7 @@ from math import atan2 +import ufl from ufl.log import error, UFLValueError from ufl.core.expr import Expr from ufl.core.terminal import Terminal @@ -425,7 +426,7 @@ def __eps(self, x): def as_ufl(expression): "Converts expression to an Expr if possible." - if isinstance(expression, Expr): + if isinstance(expression, (Expr, ufl.BaseForm)): return expression elif isinstance(expression, complex): return ComplexValue(expression) diff --git a/ufl/core/expr.py b/ufl/core/expr.py index e6c355603..b856c3d61 100644 --- a/ufl/core/expr.py +++ b/ufl/core/expr.py @@ -22,11 +22,12 @@ import warnings from ufl.log import error +from ufl.core.ufl_type import UFLType, update_ufl_type_attributes -# --- The base object for all UFL expression tree nodes --- +# --- The base object for all UFL expression tree nodes --- -class Expr(object): +class Expr(object, metaclass=UFLType): """Base class for all UFL expression types. *Instance properties* @@ -130,22 +131,10 @@ def __init__(self): # implement for this type in a multifunction. _ufl_handler_name_ = "expr" - # The integer typecode, a contiguous index different for each - # type. This is used for fast lookup into e.g. multifunction - # handler tables. - _ufl_typecode_ = 0 - # Number of operands, "varying" for some types, or None if not # applicable for abstract types. _ufl_num_ops_ = None - # Type trait: If the type is abstract. An abstract class cannot - # be instantiated and does not need all properties specified. - _ufl_is_abstract_ = True - - # Type trait: If the type is terminal. - _ufl_is_terminal_ = None - # Type trait: If the type is a literal. _ufl_is_literal_ = None @@ -229,15 +218,6 @@ def __init__(self): # --- Global variables for collecting all types --- - # A global counter of the number of typecodes assigned - _ufl_num_typecodes_ = 1 - - # A global set of all handler names added - _ufl_all_handler_names_ = set() - - # A global array of all Expr subclasses, indexed by typecode - _ufl_all_classes_ = [] - # A global dict mapping language_operator_name to the type it # produces _ufl_language_operators_ = {} @@ -247,14 +227,6 @@ def __init__(self): # --- Mechanism for profiling object creation and deletion --- - # A global array of the number of initialized objects for each - # typecode - _ufl_obj_init_counts_ = [0] - - # A global array of the number of deleted objects for each - # typecode - _ufl_obj_del_counts_ = [0] - # Backup of default init and del _ufl_regular__init__ = __init__ @@ -424,8 +396,10 @@ def __round__(self, n=None): # Initializing traits here because Expr is not defined in the class # declaration Expr._ufl_class_ = Expr -Expr._ufl_all_handler_names_.add(Expr) -Expr._ufl_all_classes_.append(Expr) + +# Update Expr with metaclass properties (e.g. typecode or handler name) +# Explicitly done here instead of using `@ufl_type` to avoid circular imports. +update_ufl_type_attributes(Expr) def ufl_err_str(expr): diff --git a/ufl/core/terminal.py b/ufl/core/terminal.py index 5b0840268..0f81e6bd1 100644 --- a/ufl/core/terminal.py +++ b/ufl/core/terminal.py @@ -99,7 +99,7 @@ def __eq__(self, other): @ufl_type(is_abstract=True) class FormArgument(Terminal): - "An abstract class for a form argument." + "An abstract class for a form argument (a thing in a primal finite element space)." __slots__ = () def __init__(self): diff --git a/ufl/core/ufl_type.py b/ufl/core/ufl_type.py index 9c4468a6f..e6435684c 100644 --- a/ufl/core/ufl_type.py +++ b/ufl/core/ufl_type.py @@ -8,9 +8,10 @@ # # Modified by Massimiliano Leoni, 2016 -from ufl.core.expr import Expr from ufl.core.compute_expr_hash import compute_expr_hash from ufl.utils.formatting import camel2underscore +# Avoid circular import +import ufl.core as core # Make UFL type coercion available under the as_ufl name @@ -98,12 +99,12 @@ def check_is_terminal_consistency(cls): def check_abstract_trait_consistency(cls): "Check that the first base classes up to ``Expr`` are other UFL types." for base in cls.mro(): - if base is Expr: + if base is core.expr.Expr: break - if not issubclass(base, Expr) and base._ufl_is_abstract_: + if not issubclass(base, core.expr.Expr) and base._ufl_is_abstract_: msg = ("Base class {0.__name__} of class {1.__name__} " "is not an abstract subclass of {2.__name__}.") - raise TypeError(msg.format(base, cls, Expr)) + raise TypeError(msg.format(base, cls, core.expr.Expr)) def check_has_slots(cls): @@ -129,6 +130,7 @@ def check_type_traits_consistency(cls): "Execute a variety of consistency checks on the ufl type traits." # Check for consistency in global type collection sizes + Expr = core.expr.Expr assert Expr._ufl_num_typecodes_ == len(Expr._ufl_all_handler_names_) assert Expr._ufl_num_typecodes_ == len(Expr._ufl_all_classes_) assert Expr._ufl_num_typecodes_ == len(Expr._ufl_obj_init_counts_) @@ -161,7 +163,7 @@ def check_type_traits_consistency(cls): def check_implements_required_methods(cls): """Check if type implements the required methods.""" if not cls._ufl_is_abstract_: - for attr in Expr._ufl_required_methods_: + for attr in core.expr.Expr._ufl_required_methods_: if not hasattr(cls, attr): msg = "Class {0.__name__} has no {1} method." raise TypeError(msg.format(cls, attr)) @@ -173,7 +175,7 @@ def check_implements_required_methods(cls): def check_implements_required_properties(cls): "Check if type implements the required properties." if not cls._ufl_is_abstract_: - for attr in Expr._ufl_required_properties_: + for attr in core.expr.Expr._ufl_required_properties_: if not hasattr(cls, attr): msg = "Class {0.__name__} has no {1} property." raise TypeError(msg.format(cls, attr)) @@ -214,11 +216,8 @@ def _inherited_ufl_index_dimensions(self): def update_global_expr_attributes(cls): "Update global ``Expr`` attributes, mainly by adding *cls* to global collections of ufl types." - Expr._ufl_all_classes_.append(cls) - Expr._ufl_all_handler_names_.add(cls._ufl_handler_name_) - if cls._ufl_is_terminal_modifier_: - Expr._ufl_terminal_modifiers_.append(cls) + core.expr.Expr._ufl_terminal_modifiers_.append(cls) # Add to collection of language operators. This collection is # used later to populate the official language namespace. @@ -226,12 +225,24 @@ def update_global_expr_attributes(cls): # it out later. if not cls._ufl_is_abstract_ and hasattr(cls, "_ufl_function_"): cls._ufl_function_.__func__.__doc__ = cls.__doc__ - Expr._ufl_language_operators_[cls._ufl_handler_name_] = cls._ufl_function_ + core.expr.Expr._ufl_language_operators_[cls._ufl_handler_name_] = cls._ufl_function_ + + +def update_ufl_type_attributes(cls): + # Determine integer typecode by incrementally counting all types + cls._ufl_typecode_ = UFLType._ufl_num_typecodes_ + UFLType._ufl_num_typecodes_ += 1 + + UFLType._ufl_all_classes_.append(cls) + + # Determine handler name by a mapping from "TypeName" to "type_name" + cls._ufl_handler_name_ = camel2underscore(cls.__name__) + UFLType._ufl_all_handler_names_.add(cls._ufl_handler_name_) # Append space for counting object creation and destriction of # this this type. - Expr._ufl_obj_init_counts_.append(0) - Expr._ufl_obj_del_counts_.append(0) + UFLType._ufl_obj_init_counts_.append(0) + UFLType._ufl_obj_del_counts_.append(0) def ufl_type(is_abstract=False, @@ -253,7 +264,7 @@ def ufl_type(is_abstract=False, unop=None, binop=None, rbinop=None): - """This decorator is to be applied to every subclass in the UFL ``Expr`` hierarchy. + """This decorator is to be applied to every subclass in the UFL ``Expr`` and ``BaseForm`` hierarchy. This decorator contains a number of checks that are intended to enforce uniform behaviour across UFL types. @@ -264,14 +275,14 @@ def ufl_type(is_abstract=False, """ def _ufl_type_decorator_(cls): - # Determine integer typecode by oncrementally counting all types - typecode = Expr._ufl_num_typecodes_ - Expr._ufl_num_typecodes_ += 1 - # Determine handler name by a mapping from "TypeName" to "type_name" - handler_name = camel2underscore(cls.__name__) + # Update attributes for UFLType instances (BaseForm and Expr objects) + update_ufl_type_attributes(cls) + if not issubclass(cls, core.expr.Expr): + # Don't need anything else for non Expr subclasses + return cls - # is_scalar implies is_index_free + # is_scalar implies is_index_freeg if is_scalar: _is_index_free = True else: @@ -279,8 +290,6 @@ def _ufl_type_decorator_(cls): # Store type traits cls._ufl_class_ = cls - set_trait(cls, "handler_name", handler_name, inherit=False) - set_trait(cls, "typecode", typecode, inherit=False) set_trait(cls, "is_abstract", is_abstract, inherit=False) set_trait(cls, "is_terminal", is_terminal, inherit=True) @@ -372,3 +381,37 @@ def _ufl_expr_rbinop_(self, other): return cls return _ufl_type_decorator_ + + +class UFLType(type): + """Base class for all UFL types. + + Equip UFL types with some ufl specific properties. + """ + + # A global counter of the number of typecodes assigned. + _ufl_num_typecodes_ = 0 + + # Set the handler name for UFLType + _ufl_handler_name_ = "ufl_type" + + # A global array of all Expr and BaseForm subclasses, indexed by typecode + _ufl_all_classes_ = [] + + # A global set of all handler names added + _ufl_all_handler_names_ = set() + + # A global array of the number of initialized objects for each + # typecode + _ufl_obj_init_counts_ = [] + + # A global array of the number of deleted objects for each + # typecode + _ufl_obj_del_counts_ = [] + + # Type trait: If the type is abstract. An abstract class cannot + # be instantiated and does not need all properties specified. + _ufl_is_abstract_ = True + + # Type trait: If the type is terminal. + _ufl_is_terminal_ = None diff --git a/ufl/corealg/multifunction.py b/ufl/corealg/multifunction.py index 16b08a3ba..e4d9f6ca6 100644 --- a/ufl/corealg/multifunction.py +++ b/ufl/corealg/multifunction.py @@ -13,6 +13,7 @@ from ufl.log import error from ufl.core.expr import Expr +from ufl.core.ufl_type import UFLType def get_num_args(function): @@ -66,7 +67,14 @@ def __init__(self): for c in classobject.mro(): # Register classobject with handler for the first # encountered superclass - handler_name = c._ufl_handler_name_ + try: + handler_name = c._ufl_handler_name_ + except AttributeError as attribute_error: + if type(classobject) is not UFLType: + raise attribute_error + # Default handler name for UFL types + handler_name = UFLType._ufl_handler_name_ + if hasattr(self, handler_name): handler_names[classobject._ufl_typecode_] = handler_name break @@ -107,5 +115,5 @@ def reuse_if_untouched(self, o, *ops): else: return o._ufl_expr_reconstruct_(*ops) - # Set default behaviour for any Expr as undefined - expr = undefined + # Set default behaviour for any UFLType as undefined + ufl_type = undefined diff --git a/ufl/differentiation.py b/ufl/differentiation.py index 049dd2685..8f678a9d3 100644 --- a/ufl/differentiation.py +++ b/ufl/differentiation.py @@ -15,6 +15,7 @@ from ufl.core.ufl_type import ufl_type from ufl.domain import extract_unique_domain, find_geometric_dimension from ufl.exprcontainers import ExprList, ExprMapping +from ufl.form import BaseForm from ufl.log import error from ufl.precedence import parstr from ufl.variable import Variable @@ -76,6 +77,25 @@ def __str__(self): self.ufl_operands[2], self.ufl_operands[3]) +@ufl_type(num_ops=4, inherit_shape_from_operand=0, + inherit_indices_from_operand=0) +class BaseFormDerivative(CoefficientDerivative, BaseForm): + """Derivative of a base form w.r.t the + degrees of freedom in a discrete Coefficient.""" + _ufl_noslots_ = True + + def __init__(self, base_form, coefficients, arguments, + coefficient_derivatives): + CoefficientDerivative.__init__(self, base_form, coefficients, arguments, + coefficient_derivatives) + BaseForm.__init__(self) + + def _analyze_form_arguments(self): + """Collect the arguments of the corresponding BaseForm""" + base_form = self.ufl_operands[0] + self._arguments = base_form.arguments() + + @ufl_type(num_ops=2) class VariableDerivative(Derivative): __slots__ = ( diff --git a/ufl/duals.py b/ufl/duals.py new file mode 100644 index 000000000..c0f1b15dc --- /dev/null +++ b/ufl/duals.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +"""Predicates for recognising duals""" + +# Copyright (C) 2021 India Marsden +# +# This file is part of UFL (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later +# + + +def is_primal(object): + """Determine if the object belongs to a primal space + + This is not simply the negation of :func:`is_dual`, + because a mixed function space containing both primal + and dual components is neither primal nor dual.""" + return hasattr(object, '_primal') and object._primal + + +def is_dual(object): + """Determine if the object belongs to a dual space + + This is not simply the negation of :func:`is_primal`, + because a mixed function space containing both primal + and dual components is neither primal nor dual.""" + return hasattr(object, '_dual') and object._dual diff --git a/ufl/exprcontainers.py b/ufl/exprcontainers.py index 119321b01..c50050953 100644 --- a/ufl/exprcontainers.py +++ b/ufl/exprcontainers.py @@ -11,6 +11,8 @@ from ufl.core.expr import Expr from ufl.core.operator import Operator from ufl.core.ufl_type import ufl_type +from ufl.coefficient import Cofunction +from ufl.argument import Coargument # --- Non-tensor types --- @@ -22,8 +24,9 @@ class ExprList(Operator): def __init__(self, *operands): Operator.__init__(self, operands) - if not all(isinstance(i, Expr) for i in operands): - error("Expecting Expr in ExprList.") + # Enable Cofunction/Coargument for BaseForm differentiation + if not all(isinstance(i, (Expr, Cofunction, Coargument)) for i in operands): + error("Expecting Expr, Cofunction or Coargument in ExprList.") def __getitem__(self, i): return self.ufl_operands[i] diff --git a/ufl/form.py b/ufl/form.py index 838a93a1b..a1959de5f 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -20,12 +20,12 @@ from ufl.integral import Integral from ufl.checks import is_scalar_constant_expression from ufl.equation import Equation -from ufl.core.expr import Expr -from ufl.core.expr import ufl_err_str +from ufl.core.expr import Expr, ufl_err_str +from ufl.core.ufl_type import UFLType, ufl_type from ufl.constantvalue import Zero # Export list for ufl.classes -__all_classes__ = ["Form"] +__all_classes__ = ["Form", "BaseForm", "ZeroBaseForm"] # --- The Form class, representing a complete variational form or functional --- @@ -70,7 +70,189 @@ def _sorted_integrals(integrals): return tuple(all_integrals) # integrals_dict -class Form(object): +@ufl_type() +class BaseForm(object, metaclass=UFLType): + """Description of an object containing arguments""" + + # Slots is kept empty to enable multiple inheritance with other classes. + __slots__ = () + _ufl_is_abstract_ = True + _ufl_required_methods_ = ('_analyze_form_arguments', "ufl_domains") + + def __init__(self): + # Internal variables for caching form argument data + self._arguments = None + + # --- Accessor interface --- + def arguments(self): + "Return all ``Argument`` objects found in form." + if self._arguments is None: + self._analyze_form_arguments() + return self._arguments + + # --- Operator implementations --- + + def __eq__(self, other): + """Delayed evaluation of the == operator! + + Just 'lhs_form == rhs_form' gives an Equation, + while 'bool(lhs_form == rhs_form)' delegates + to lhs_form.equals(rhs_form). + """ + return Equation(self, other) + + def __radd__(self, other): + # Ordering of form additions make no difference + return self.__add__(other) + + def __add__(self, other): + if isinstance(other, (int, float)) and other == 0: + # Allow adding 0 or 0.0 as a no-op, needed for sum([a,b]) + return self + + elif isinstance( + other, + Zero) and not (other.ufl_shape or other.ufl_free_indices): + # Allow adding ufl Zero as a no-op, needed for sum([a,b]) + return self + + elif isinstance(other, ZeroBaseForm): + self._check_arguments_sum(other) + # Simplify addition with ZeroBaseForm + return self + + # For `ZeroBaseForm(...) + B` with B a BaseForm. + # We could overwrite ZeroBaseForm.__add__ but that implies + # duplicating cases with `0` and `ufl.Zero`. + elif isinstance(self, ZeroBaseForm): + self._check_arguments_sum(other) + # Simplify addition with ZeroBaseForm + return other + + elif isinstance(other, BaseForm): + # Add integrals from both forms + return FormSum((self, 1), (other, 1)) + + else: + # Let python protocols do their job if we don't handle it + return NotImplemented + + def _check_arguments_sum(self, other): + # Get component with the highest number of arguments + a = max((self, other), key=lambda x: len(x.arguments())) + b = self if a is other else other + # Components don't necessarily have the exact same arguments + # but the first argument(s) need to match as for `a + L` + # where a and L are a bilinear and linear form respectively. + a_args = sorted(a.arguments(), key=lambda x: x.number()) + b_args = sorted(b.arguments(), key=lambda x: x.number()) + if b_args != a_args[:len(b_args)]: + raise ValueError('Mismatching arguments when summing:\n %s\n and\n %s' % (self, other)) + + def __sub__(self, other): + "Subtract other form from this one." + return self + (-other) + + def __rsub__(self, other): + "Subtract this form from other." + return other + (-self) + + def __neg__(self): + """Negate all integrals in form. + + This enables the handy "-form" syntax for e.g. the + linearized system (J, -F) from a nonlinear form F.""" + if isinstance(self, ZeroBaseForm): + # `-` doesn't change anything for ZeroBaseForm. + # This also facilitates simplifying FormSum containing ZeroBaseForm objects. + return self + return FormSum((self, -1)) + + def __rmul__(self, scalar): + "Multiply all integrals in form with constant scalar value." + # This enables the handy "0*form" or "dt*form" syntax + if is_scalar_constant_expression(scalar): + return FormSum((self, scalar)) + return NotImplemented + + def __mul__(self, coefficient): + "Take the action of this form on the given coefficient." + if isinstance(coefficient, Expr): + from ufl.formoperators import action + return action(self, coefficient) + return NotImplemented + + def __ne__(self, other): + "Immediately evaluate the != operator (as opposed to the == operator)." + return not self.equals(other) + + def __call__(self, *args, **kwargs): + """Evaluate form by replacing arguments and coefficients. + + Replaces form.arguments() with given positional arguments in + same number and ordering. Number of positional arguments must + be 0 or equal to the number of Arguments in the form. + + The optional keyword argument coefficients can be set to a dict + to replace Coefficients with expressions of matching shapes. + + Example: + ------- + V = FiniteElement("CG", triangle, 1) + v = TestFunction(V) + u = TrialFunction(V) + f = Coefficient(V) + g = Coefficient(V) + a = g*inner(grad(u), grad(v))*dx + M = a(f, f, coefficients={ g: 1 }) + + Is equivalent to M == grad(f)**2*dx. + + """ + repdict = {} + + if args: + arguments = self.arguments() + if len(arguments) != len(args): + error("Need %d arguments to form(), got %d." % (len(arguments), + len(args))) + repdict.update(zip(arguments, args)) + + coefficients = kwargs.pop("coefficients") + if kwargs: + error("Unknown kwargs %s." % str(list(kwargs))) + + if coefficients is not None: + coeffs = self.coefficients() + for f in coefficients: + if f in coeffs: + repdict[f] = coefficients[f] + else: + warnings("Coefficient %s is not in form." % ufl_err_str(f)) + if repdict: + from ufl.formoperators import replace + return replace(self, repdict) + else: + return self + + def _ufl_compute_hash_(self): + "Compute the hash" + # Ensure compatibility with MultiFunction + # `hash(self)` will call the `__hash__` method of the subclass. + return hash(self) + + def _ufl_expr_reconstruct_(self, *operands): + "Return a new object of the same type with new operands." + return type(self)(*operands) + + # "a @ f" notation in python 3.5 + __matmul__ = __mul__ + + # --- String conversion functions, for UI purposes only --- + + +@ufl_type() +class Form(BaseForm): """Description of a weak form consisting of a sum of integrals over subdomains.""" __slots__ = ( # --- List of Integral objects (a Form is a sum of these Integrals, everything else is derived) @@ -93,6 +275,7 @@ class Form(object): ) def __init__(self, integrals): + BaseForm.__init__(self) # Basic input checking (further compatibilty analysis happens # later) if not all(isinstance(itg, Integral) for itg in integrals): @@ -110,7 +293,6 @@ def __init__(self, integrals): self._subdomain_data = None # Internal variables for caching form argument data - self._arguments = None self._coefficients = None self._coefficient_numbering = None self._constant_numbering = None @@ -261,15 +443,6 @@ def __hash__(self): self._hash = hash(tuple(hash(itg) for itg in self.integrals())) return self._hash - def __eq__(self, other): - """Delayed evaluation of the == operator! - - Just 'lhs_form == rhs_form' gives an Equation, - while 'bool(lhs_form == rhs_form)' delegates - to lhs_form.equals(rhs_form). - """ - return Equation(self, other) - def __ne__(self, other): "Immediate evaluation of the != operator (as opposed to the == operator)." return not self.equals(other) @@ -293,6 +466,15 @@ def __add__(self, other): # Add integrals from both forms return Form(list(chain(self.integrals(), other.integrals()))) + if isinstance(other, ZeroBaseForm): + self._check_arguments_sum(other) + # Simplify addition with ZeroBaseForm + return self + + elif isinstance(other, BaseForm): + # Create form sum if form is of other type + return FormSum((self, 1), (other, 1)) + elif isinstance(other, (int, float)) and other == 0: # Allow adding 0 or 0.0 as a no-op, needed for sum([a,b]) return self @@ -512,7 +694,7 @@ def sub_forms_by_domain(form): def as_form(form): "Convert to form if not a form, otherwise return form." - if not isinstance(form, Form): + if not isinstance(form, BaseForm): error("Unable to convert object to a UFL form: %s" % ufl_err_str(form)) return form @@ -547,3 +729,155 @@ def replace_integral_domains(form, common_domain): # TODO: Move elsewhere if reconstruct: form = Form(integrals) return form + + +@ufl_type() +class FormSum(BaseForm): + """Description of a weighted sum of variational forms and form-like objects + components is the list of Forms to be summed + arg_weights is a list of tuples of component index and weight""" + + __slots__ = ("_arguments", + "_weights", + "_components", + "ufl_operands", + "_domains", + "_domain_numbering", + "_hash") + _ufl_required_methods_ = ('_analyze_form_arguments') + + def __init__(self, *components): + BaseForm.__init__(self) + + weights = [] + full_components = [] + for (component, w) in components: + if isinstance(component, FormSum): + full_components.extend(component.components()) + weights.extend(w * component.weights()) + else: + full_components.append(component) + weights.append(w) + + self._arguments = None + self._domains = None + self._domain_numbering = None + self._hash = None + self._weights = weights + self._components = full_components + self._sum_variational_components() + self.ufl_operands = self._components + + def components(self): + return self._components + + def weights(self): + return self._weights + + def _sum_variational_components(self): + var_forms = None + other_components = [] + new_weights = [] + for (i, component) in enumerate(self._components): + if isinstance(component, Form): + if var_forms: + var_forms = var_forms + (self._weights[i] * component) + else: + var_forms = self._weights[i] * component + else: + other_components.append(component) + new_weights.append(self._weights[i]) + if var_forms: + other_components.insert(0, var_forms) + new_weights.insert(0, 1) + self._components = other_components + self._weights = new_weights + + def _analyze_form_arguments(self): + "Return all ``Argument`` objects found in form." + arguments = [] + for component in self._components: + arguments.extend(component.arguments()) + self._arguments = tuple(set(arguments)) + + def __hash__(self): + "Hash code for use in dicts (includes incidental numbering of indices etc.)" + if self._hash is None: + self._hash = hash(tuple(hash(component) for component in self.components())) + return self._hash + + def equals(self, other): + "Evaluate ``bool(lhs_form == rhs_form)``." + if type(other) != FormSum: + return False + if self is other: + return True + return (len(self.components()) == len(other.components()) and + all(a == b for a, b in zip(self.components(), other.components()))) + + def __str__(self): + "Compute shorter string representation of form. This can be huge for complicated forms." + # Warning used for making sure we don't use this in the general pipeline: + # warning("Calling str on form is potentially expensive and should be avoided except during debugging.") + # Not caching this because it can be huge + s = "\n + ".join(str(component) for component in self.components()) + return s or "" + + def __repr__(self): + "Compute repr string of form. This can be huge for complicated forms." + # Warning used for making sure we don't use this in the general pipeline: + # warning("Calling repr on form is potentially expensive and should be avoided except during debugging.") + # Not caching this because it can be huge + itgs = ", ".join(repr(component) for component in self.components()) + r = "FormSum([" + itgs + "])" + return r + + +@ufl_type() +class ZeroBaseForm(BaseForm): + """Description of a zero base form. + ZeroBaseForm is idempotent with respect to assembly and is mostly used for sake of simplifying base-form expressions. + """ + + __slots__ = ("_arguments", + "_coefficients", + "ufl_operands", + "_hash", + # Pyadjoint compatibility + "form") + + def __init__(self, arguments): + BaseForm.__init__(self) + self._arguments = arguments + self.ufl_operands = arguments + self._hash = None + self.form = None + + def _analyze_form_arguments(self): + return self._arguments + + def __ne__(self, other): + # Overwrite BaseForm.__neq__ which relies on `equals` + return not self == other + + def __eq__(self, other): + if type(other) is ZeroBaseForm: + if self is other: + return True + return (self._arguments == other._arguments) + elif isinstance(other, (int, float)): + return other == 0 + else: + return False + + def __str__(self): + return "ZeroBaseForm(%s)" % (", ".join(str(arg) for arg in self._arguments)) + + def __repr__(self): + return "ZeroBaseForm(%s)" % (", ".join(repr(arg) for arg in self._arguments)) + + def __hash__(self): + """Hash code for use in dicts.""" + if self._hash is None: + self._hash = hash(("ZeroBaseForm", hash(self._arguments))) + return self._hash diff --git a/ufl/formoperators.py b/ufl/formoperators.py index 4fbfd7999..4a3f648ea 100644 --- a/ufl/formoperators.py +++ b/ufl/formoperators.py @@ -12,15 +12,17 @@ # Modified by Cecile Daversin-Catty, 2018 from ufl.log import error -from ufl.form import Form, as_form +from ufl.form import Form, FormSum, BaseForm, as_form from ufl.core.expr import Expr, ufl_err_str from ufl.split_functions import split from ufl.exprcontainers import ExprList, ExprMapping from ufl.variable import Variable from ufl.finiteelement import MixedElement from ufl.argument import Argument -from ufl.coefficient import Coefficient -from ufl.differentiation import CoefficientDerivative, CoordinateDerivative +from ufl.coefficient import Coefficient, Cofunction +from ufl.adjoint import Adjoint +from ufl.action import Action +from ufl.differentiation import CoefficientDerivative, BaseFormDerivative, CoordinateDerivative from ufl.constantvalue import is_true_ufl_scalar, as_ufl from ufl.indexed import Indexed from ufl.core.multiindex import FixedIndex, MultiIndex @@ -104,10 +106,14 @@ def action(form, coefficient=None): Given a bilinear form, return a linear form with an additional coefficient, representing the action of the form on the coefficient. This can be - used for matrix-free methods.""" + used for matrix-free methods. + For formbase objects,coefficient can be any object of the correct type, + and this function returns an Action object.""" form = as_form(form) - form = expand_derivatives(form) - return compute_form_action(form, coefficient) + if isinstance(form, Form) and not (isinstance(coefficient, BaseForm) and len(coefficient.arguments()) > 1): + form = expand_derivatives(form) + return compute_form_action(form, coefficient) + return Action(form, coefficient) def energy_norm(form, coefficient=None): @@ -129,10 +135,15 @@ def adjoint(form, reordered_arguments=None): opposite ordering. However, if the adjoint form is to be added to other forms later, their arguments must match. In that case, the user must provide a tuple *reordered_arguments*=(u2,v2). + + If the form is a baseform instance instead of a Form object, we return an Adjoint + object instructing the adjoint to be computed at a later point. """ form = as_form(form) - form = expand_derivatives(form) - return compute_form_adjoint(form, reordered_arguments) + if isinstance(form, Form): + form = expand_derivatives(form) + return compute_form_adjoint(form, reordered_arguments) + return Adjoint(form) def zero_lists(shape): @@ -162,7 +173,7 @@ def _handle_derivative_arguments(form, coefficient, argument): if argument is None: # Try to create argument if not provided - if not all(isinstance(c, Coefficient) for c in coefficients): + if not all(isinstance(c, (Coefficient, Cofunction)) for c in coefficients): error("Can only create arguments automatically for non-indexed coefficients.") # Get existing arguments from form and position the new one @@ -215,7 +226,7 @@ def _handle_derivative_arguments(form, coefficient, argument): for (c, a) in zip(coefficients, arguments): if c.ufl_shape != a.ufl_shape: error("Coefficient and argument shapes do not match!") - if isinstance(c, Coefficient) or isinstance(c, SpatialCoordinate): + if isinstance(c, (Coefficient, Cofunction, SpatialCoordinate)): m[c] = a else: if not isinstance(c, Indexed): @@ -269,9 +280,23 @@ def derivative(form, coefficient, argument=None, coefficient_derivatives=None): ``Coefficient`` instances to their derivatives w.r.t. *coefficient*. """ + if isinstance(form, FormSum): + # Distribute derivative over FormSum components + return FormSum(*[(derivative(component, coefficient, argument, coefficient_derivatives), 1) + for component in form.components()]) + elif isinstance(form, Adjoint): + # Push derivative through Adjoint + return adjoint(derivative(form._form, coefficient, argument, coefficient_derivatives)) + elif isinstance(form, Action): + # Push derivative through Action slots + left, right = form.ufl_operands + dleft = derivative(left, coefficient, argument, coefficient_derivatives) + dright = derivative(right, coefficient, argument, coefficient_derivatives) + # Leibniz formula + return action(dleft, right) + action(left, dright) + coefficients, arguments = _handle_derivative_arguments(form, coefficient, argument) - if coefficient_derivatives is None: coefficient_derivatives = ExprMapping() else: @@ -293,6 +318,10 @@ def derivative(form, coefficient, argument=None, coefficient_derivatives=None): integrals.append(itg.reconstruct(fd)) return Form(integrals) + elif isinstance(form, BaseForm): + if not isinstance(coefficient, SpatialCoordinate): + return BaseFormDerivative(form, coefficients, arguments, coefficient_derivatives) + elif isinstance(form, Expr): # What we got was in fact an integrand if not isinstance(coefficient, SpatialCoordinate): diff --git a/ufl/functionspace.py b/ufl/functionspace.py index 2758b2e7d..6842d9ca2 100644 --- a/ufl/functionspace.py +++ b/ufl/functionspace.py @@ -13,11 +13,13 @@ from ufl.log import error from ufl.core.ufl_type import attach_operators_from_hash_data from ufl.domain import join_domains +from ufl.duals import is_dual, is_primal # Export list for ufl.classes __all_classes__ = [ "AbstractFunctionSpace", "FunctionSpace", + "DualSpace", "MixedFunctionSpace", "TensorProductFunctionSpace", ] @@ -25,11 +27,14 @@ class AbstractFunctionSpace(object): def ufl_sub_spaces(self): - raise NotImplementedError("Missing implementation of IFunctionSpace.ufl_sub_spaces in %s." % self.__class__.__name__) + raise NotImplementedError( + "Missing implementation of IFunctionSpace.ufl_sub_spaces in %s." + % self.__class__.__name__ + ) @attach_operators_from_hash_data -class FunctionSpace(AbstractFunctionSpace): +class BaseFunctionSpace(AbstractFunctionSpace): def __init__(self, domain, element): if domain is None: # DOLFIN hack @@ -39,7 +44,8 @@ def __init__(self, domain, element): try: domain_cell = domain.ufl_cell() except AttributeError: - error("Expected non-abstract domain for initalization of function space.") + error("Expected non-abstract domain for initalization " + "of function space.") else: if element.cell() != domain_cell: error("Non-matching cell of finite element and domain.") @@ -68,7 +74,8 @@ def ufl_domains(self): else: return (domain,) - def _ufl_hash_data_(self): + def _ufl_hash_data_(self, name=None): + name = name or "BaseFunctionSpace" domain = self.ufl_domain() element = self.ufl_element() if domain is None: @@ -79,9 +86,10 @@ def _ufl_hash_data_(self): edata = None else: edata = element._ufl_hash_data_() - return ("FunctionSpace", ddata, edata) + return (name, ddata, edata) - def _ufl_signature_data_(self, renumbering): + def _ufl_signature_data_(self, renumbering, name=None): + name = name or "BaseFunctionSpace" domain = self.ufl_domain() element = self.ufl_element() if domain is None: @@ -92,10 +100,56 @@ def _ufl_signature_data_(self, renumbering): edata = None else: edata = element._ufl_signature_data_() - return ("FunctionSpace", ddata, edata) + return (name, ddata, edata) + + def __repr__(self): + r = "BaseFunctionSpace(%s, %s)" % (repr(self._ufl_domain), + repr(self._ufl_element)) + return r + + +@attach_operators_from_hash_data +class FunctionSpace(BaseFunctionSpace): + """Representation of a Function space.""" + _primal = True + _dual = False + + def dual(self): + return DualSpace(self._ufl_domain, self._ufl_element) + + def _ufl_hash_data_(self): + return BaseFunctionSpace._ufl_hash_data_(self, "FunctionSpace") + + def _ufl_signature_data_(self, renumbering): + return BaseFunctionSpace._ufl_signature_data_(self, renumbering, "FunctionSpace") + + def __repr__(self): + r = "FunctionSpace(%s, %s)" % (repr(self._ufl_domain), + repr(self._ufl_element)) + return r + + +@attach_operators_from_hash_data +class DualSpace(BaseFunctionSpace): + """Representation of a Dual space.""" + _primal = False + _dual = True + + def __init__(self, domain, element): + BaseFunctionSpace.__init__(self, domain, element) + + def dual(self): + return FunctionSpace(self._ufl_domain, self._ufl_element) + + def _ufl_hash_data_(self): + return BaseFunctionSpace._ufl_hash_data_(self, "DualSpace") + + def _ufl_signature_data_(self, renumbering): + return BaseFunctionSpace._ufl_signature_data_(self, renumbering, "DualSpace") def __repr__(self): - r = "FunctionSpace(%s, %s)" % (repr(self._ufl_domain), repr(self._ufl_element)) + r = "DualSpace(%s, %s)" % (repr(self._ufl_domain), + repr(self._ufl_element)) return r @@ -109,10 +163,13 @@ def ufl_sub_spaces(self): return self._ufl_function_spaces def _ufl_hash_data_(self): - return ("TensorProductFunctionSpace",) + tuple(V._ufl_hash_data_() for V in self.ufl_sub_spaces()) + return ("TensorProductFunctionSpace",) \ + + tuple(V._ufl_hash_data_() for V in self.ufl_sub_spaces()) def _ufl_signature_data_(self, renumbering): - return ("TensorProductFunctionSpace",) + tuple(V._ufl_signature_data_(renumbering) for V in self.ufl_sub_spaces()) + return ("TensorProductFunctionSpace",) \ + + tuple(V._ufl_signature_data_(renumbering) + for V in self.ufl_sub_spaces()) def __repr__(self): r = "TensorProductFunctionSpace(*%s)" % repr(self._ufl_function_spaces) @@ -121,15 +178,22 @@ def __repr__(self): @attach_operators_from_hash_data class MixedFunctionSpace(AbstractFunctionSpace): + def __init__(self, *args): AbstractFunctionSpace.__init__(self) self._ufl_function_spaces = args self._ufl_elements = list() for fs in args: - if isinstance(fs, FunctionSpace): + if isinstance(fs, BaseFunctionSpace): self._ufl_elements.append(fs.ufl_element()) else: - error("Expecting FunctionSpace objects") + error("Expecting BaseFunctionSpace objects") + + # A mixed FS is only primal/dual if all the subspaces are primal/dual" + self._primal = all([is_primal(subspace) + for subspace in self._ufl_function_spaces]) + self._dual = all([is_dual(subspace) + for subspace in self._ufl_function_spaces]) def ufl_sub_spaces(self): "Return ufl sub spaces." @@ -139,6 +203,25 @@ def ufl_sub_space(self, i): "Return i-th ufl sub space." return self._ufl_function_spaces[i] + def dual(self, *args): + """Return the dual to this function space. + + If no additional arguments are passed then a MixedFunctionSpace is + returned whose components are the duals of the originals. + + If additional arguments are passed, these must be integers. In this + case, the MixedFunctionSpace which is returned will have dual + components in the positions corresponding to the arguments passed, and + the original components in the other positions.""" + if args: + spaces = [space.dual() if i in args else space + for i, space in enumerate(self._ufl_function_spaces)] + return MixedFunctionSpace(*spaces) + else: + return MixedFunctionSpace( + *[space.dual()for space in self._ufl_function_spaces] + ) + def ufl_elements(self): "Return ufl elements." return self._ufl_elements @@ -172,10 +255,13 @@ def num_sub_spaces(self): return len(self._ufl_function_spaces) def _ufl_hash_data_(self): - return ("MixedFunctionSpace",) + tuple(V._ufl_hash_data_() for V in self.ufl_sub_spaces()) + return ("MixedFunctionSpace",) \ + + tuple(V._ufl_hash_data_() for V in self.ufl_sub_spaces()) def _ufl_signature_data_(self, renumbering): - return ("MixedFunctionSpace",) + tuple(V._ufl_signature_data_(renumbering) for V in self.ufl_sub_spaces()) + return ("MixedFunctionSpace",) \ + + tuple(V._ufl_signature_data_(renumbering) + for V in self.ufl_sub_spaces()) def __repr__(self): r = "MixedFunctionSpace(*%s)" % repr(self._ufl_function_spaces) diff --git a/ufl/matrix.py b/ufl/matrix.py new file mode 100644 index 000000000..a5523fe52 --- /dev/null +++ b/ufl/matrix.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- +"""This module defines the Matrix class.""" + +# Copyright (C) 2021 India Marsden +# +# This file is part of UFL (https://www.fenicsproject.org) +# +# SPDX-License-Identifier: LGPL-3.0-or-later +# +# Modified by Nacime Bouziani, 2021-2022. + +from ufl.log import error +from ufl.form import BaseForm +from ufl.core.ufl_type import ufl_type +from ufl.argument import Argument +from ufl.functionspace import AbstractFunctionSpace +from ufl.utils.counted import counted_init + + +# --- The Matrix class represents a matrix, an assembled two form --- + +@ufl_type() +class Matrix(BaseForm): + """An assemble linear operator between two function spaces.""" + + __slots__ = ( + "_count", + "_ufl_function_spaces", + "ufl_operands", + "_repr", + "_hash", + "_ufl_shape", + "_arguments") + _globalcount = 0 + + def __getnewargs__(self): + return (self._ufl_function_spaces[0], self._ufl_function_spaces[1], + self._count) + + def __init__(self, row_space, column_space, count=None): + BaseForm.__init__(self) + counted_init(self, count, Matrix) + + if not isinstance(row_space, AbstractFunctionSpace): + error("Expecting a FunctionSpace as the row space.") + + if not isinstance(column_space, AbstractFunctionSpace): + error("Expecting a FunctionSpace as the column space.") + + self._ufl_function_spaces = (row_space, column_space) + + self.ufl_operands = () + self._hash = None + self._repr = "Matrix(%s,%s, %s)" % ( + repr(self._ufl_function_spaces[0]), + repr(self._ufl_function_spaces[1]), repr(self._count) + ) + + def count(self): + return self._count + + def ufl_function_spaces(self): + "Get the tuple of function spaces of this coefficient." + return self._ufl_function_spaces + + def ufl_row_space(self): + return self._ufl_function_spaces[0] + + def ufl_column_space(self): + return self._ufl_function_spaces[1] + + def _analyze_form_arguments(self): + "Define arguments of a matrix when considered as a form." + self._arguments = (Argument(self._ufl_function_spaces[0], 0), + Argument(self._ufl_function_spaces[1], 1)) + + def __str__(self): + count = str(self._count) + if len(count) == 1: + return "A_%s" % count + else: + return "A_{%s}" % count + + def __repr__(self): + return self._repr + + def __hash__(self): + "Hash code for use in dicts " + if self._hash is None: + self._hash = hash(self._repr) + return self._hash + + def equals(self, other): + if type(other) is not Matrix: + return False + if self is other: + return True + return (self._count == other._count and + self._ufl_function_spaces == other._ufl_function_spaces)