Skip to content

Commit

Permalink
test_create_fi(n)at_element
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 19, 2024
1 parent 406da7e commit bcc3ebe
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 3 deletions.
3 changes: 0 additions & 3 deletions FIAT/macro.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from itertools import chain, combinations

from functools import cache
import numpy

from FIAT import expansions, polynomial_set
Expand Down Expand Up @@ -196,7 +195,6 @@ def get_parent_complex(self):
return self._parent_complex


@cache
class IsoSplit(SplitSimplicialComplex):
"""Splits simplex into the simplicial complex obtained by
connecting points on a regular lattice.
Expand Down Expand Up @@ -302,7 +300,6 @@ def construct_subcomplex(self, dimension):
return PowellSabinSplit(subcomplex, dimension=self.split_dimension)


@cache
class AlfeldSplit(PowellSabinSplit):
"""Splits a simplicial complex by connecting cell vertices to their
barycenter.
Expand Down
150 changes: 150 additions & 0 deletions test/finat/test_create_fiat_element.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

import FIAT
from FIAT.discontinuous_lagrange import DiscontinuousLagrange as FIAT_DiscontinuousLagrange

import ufl
import finat.ufl
from finat.element_factory import create_element as _create_element


supported_elements = {
# These all map directly to FIAT elements
"Brezzi-Douglas-Marini": FIAT.BrezziDouglasMarini,
"Brezzi-Douglas-Fortin-Marini": FIAT.BrezziDouglasFortinMarini,
"Lagrange": FIAT.Lagrange,
"Nedelec 1st kind H(curl)": FIAT.Nedelec,
"Nedelec 2nd kind H(curl)": FIAT.NedelecSecondKind,
"Raviart-Thomas": FIAT.RaviartThomas,
"Regge": FIAT.Regge,
}
"""A :class:`.dict` mapping UFL element family names to their
FIAT-equivalent constructors."""


def create_element(ufl_element):
"""Create a FIAT element given a UFL element."""
finat_element = _create_element(ufl_element)
return finat_element.fiat_equivalent


@pytest.fixture(params=["BDM",
"BDFM",
"Lagrange",
"N1curl",
"N2curl",
"RT",
"Regge"])
def triangle_names(request):
return request.param


@pytest.fixture
def ufl_element(triangle_names):
return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2)


def test_triangle_basic(ufl_element):
element = create_element(ufl_element)
assert isinstance(element, supported_elements[ufl_element.family()])


@pytest.fixture(params=["CG", "DG", "DG L2"], scope="module")
def tensor_name(request):
return request.param


@pytest.fixture(params=[ufl.interval, ufl.triangle,
ufl.quadrilateral],
ids=lambda x: x.cellname(),
scope="module")
def ufl_A(request, tensor_name):
return finat.ufl.FiniteElement(tensor_name, request.param, 1)


@pytest.fixture
def ufl_B(tensor_name):
return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1)


def test_tensor_prod_simple(ufl_A, ufl_B):
tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B)

tensor = create_element(tensor_ufl)
A = create_element(ufl_A)
B = create_element(ufl_B)

assert isinstance(tensor, FIAT.TensorProductElement)

assert tensor.A is A
assert tensor.B is B


@pytest.mark.parametrize(('family', 'expected_cls'),
[('P', FIAT.GaussLobattoLegendre),
('DP', FIAT.GaussLegendre),
('DP L2', FIAT.GaussLegendre)])
def test_interval_variant_default(family, expected_cls):
ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3)
assert isinstance(create_element(ufl_element), expected_cls)


@pytest.mark.parametrize(('family', 'variant', 'expected_cls'),
[('P', 'equispaced', FIAT.Lagrange),
('P', 'spectral', FIAT.GaussLobattoLegendre),
('DP', 'equispaced', FIAT_DiscontinuousLagrange),
('DP', 'spectral', FIAT.GaussLegendre),
('DP L2', 'equispaced', FIAT_DiscontinuousLagrange),
('DP L2', 'spectral', FIAT.GaussLegendre)])
def test_interval_variant(family, variant, expected_cls):
ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant)
assert isinstance(create_element(ufl_element), expected_cls)


def test_triangle_variant_spectral():
ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral')
create_element(ufl_element)


def test_triangle_variant_spectral_l2():
ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral')
create_element(ufl_element)


def test_quadrilateral_variant_spectral_q():
element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral'))
assert isinstance(element.element.A, FIAT.GaussLobattoLegendre)
assert isinstance(element.element.B, FIAT.GaussLobattoLegendre)


def test_quadrilateral_variant_spectral_dq():
element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral'))
assert isinstance(element.element.A, FIAT.GaussLegendre)
assert isinstance(element.element.B, FIAT.GaussLegendre)


def test_quadrilateral_variant_spectral_dq_l2():
element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral'))
assert isinstance(element.element.A, FIAT.GaussLegendre)
assert isinstance(element.element.B, FIAT.GaussLegendre)


def test_quadrilateral_variant_spectral_rtcf():
element = create_element(finat.ufl.FiniteElement('RTCF', ufl.quadrilateral, 2, variant='spectral'))
assert isinstance(element.element._elements[0].A, FIAT.GaussLobattoLegendre)
assert isinstance(element.element._elements[0].B, FIAT.GaussLegendre)
assert isinstance(element.element._elements[1].A, FIAT.GaussLegendre)
assert isinstance(element.element._elements[1].B, FIAT.GaussLobattoLegendre)


def test_cache_hit(ufl_element):
A = create_element(ufl_element)
B = create_element(ufl_element)

assert A is B


if __name__ == "__main__":
import os
import sys
pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])
138 changes: 138 additions & 0 deletions test/finat/test_create_finat_element.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import pytest

import ufl
import finat.ufl
import finat
from finat.element_factory import create_element, supported_elements


@pytest.fixture(params=["BDM",
"BDFM",
"Lagrange",
"N1curl",
"N2curl",
"RT",
"Regge"])
def triangle_names(request):
return request.param


@pytest.fixture
def ufl_element(triangle_names):
return finat.ufl.FiniteElement(triangle_names, ufl.triangle, 2)


def test_triangle_basic(ufl_element):
element = create_element(ufl_element)
assert isinstance(element, supported_elements[ufl_element.family()])


@pytest.fixture
def ufl_vector_element(triangle_names):
return finat.ufl.VectorElement(triangle_names, ufl.triangle, 2)


def test_triangle_vector(ufl_element, ufl_vector_element):
scalar = create_element(ufl_element)
vector = create_element(ufl_vector_element)

assert isinstance(vector, finat.TensorFiniteElement)
assert scalar == vector.base_element


@pytest.fixture(params=["CG", "DG", "DG L2"])
def tensor_name(request):
return request.param


@pytest.fixture(params=[ufl.interval, ufl.triangle,
ufl.quadrilateral],
ids=lambda x: x.cellname())
def ufl_A(request, tensor_name):
return finat.ufl.FiniteElement(tensor_name, request.param, 1)


@pytest.fixture
def ufl_B(tensor_name):
return finat.ufl.FiniteElement(tensor_name, ufl.interval, 1)


def test_tensor_prod_simple(ufl_A, ufl_B):
tensor_ufl = finat.ufl.TensorProductElement(ufl_A, ufl_B)

tensor = create_element(tensor_ufl)
A = create_element(ufl_A)
B = create_element(ufl_B)

assert isinstance(tensor, finat.TensorProductElement)

assert tensor.factors == (A, B)


@pytest.mark.parametrize(('family', 'expected_cls'),
[('P', finat.GaussLobattoLegendre),
('DP', finat.GaussLegendre),
('DP L2', finat.GaussLegendre)])
def test_interval_variant_default(family, expected_cls):
ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3)
assert isinstance(create_element(ufl_element), expected_cls)


@pytest.mark.parametrize(('family', 'variant', 'expected_cls'),
[('P', 'equispaced', finat.Lagrange),
('P', 'spectral', finat.GaussLobattoLegendre),
('DP', 'equispaced', finat.DiscontinuousLagrange),
('DP', 'spectral', finat.GaussLegendre),
('DP L2', 'equispaced', finat.DiscontinuousLagrange),
('DP L2', 'spectral', finat.GaussLegendre)])
def test_interval_variant(family, variant, expected_cls):
ufl_element = finat.ufl.FiniteElement(family, ufl.interval, 3, variant=variant)
assert isinstance(create_element(ufl_element), expected_cls)


def test_triangle_variant_spectral():
ufl_element = finat.ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral')
create_element(ufl_element)


def test_triangle_variant_spectral_l2():
ufl_element = finat.ufl.FiniteElement('DP L2', ufl.triangle, 2, variant='spectral')
create_element(ufl_element)


def test_quadrilateral_variant_spectral_q():
element = create_element(finat.ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral'))
assert isinstance(element.product.factors[0], finat.GaussLobattoLegendre)
assert isinstance(element.product.factors[1], finat.GaussLobattoLegendre)


def test_quadrilateral_variant_spectral_dq():
element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral'))
assert isinstance(element.product.factors[0], finat.GaussLegendre)
assert isinstance(element.product.factors[1], finat.GaussLegendre)


def test_quadrilateral_variant_spectral_dq_l2():
element = create_element(finat.ufl.FiniteElement('DQ L2', ufl.quadrilateral, 1, variant='spectral'))
assert isinstance(element.product.factors[0], finat.GaussLegendre)
assert isinstance(element.product.factors[1], finat.GaussLegendre)


def test_cache_hit(ufl_element):
A = create_element(ufl_element)
B = create_element(ufl_element)

assert A is B


def test_cache_hit_vector(ufl_vector_element):
A = create_element(ufl_vector_element)
B = create_element(ufl_vector_element)

assert A is B


if __name__ == "__main__":
import os
import sys
pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])

0 comments on commit bcc3ebe

Please sign in to comment.