From be411c5572ff88526d2526b4b80027b4602672fb Mon Sep 17 00:00:00 2001
From: ksagiyam <46749170+ksagiyam@users.noreply.github.com>
Date: Wed, 11 Dec 2024 16:19:58 +0000
Subject: [PATCH] enable restricted function space on extruded meshes (#3905)

---
 firedrake/bcs.py                          |   8 +-
 firedrake/cython/dmcommon.pyx             | 112 +++++++++++----
 firedrake/cython/petschdr.pxi             |   1 +
 firedrake/functionspaceimpl.py            |  14 +-
 firedrake/mesh.py                         |   4 +-
 .../static_condensation/hybridization.py  |   8 +-
 pyop2/types/dat.py                        |  24 +++-
 pyop2/types/dataset.py                    |  18 ++-
 .../test_restricted_function_space.py     | 131 ++++++++++++++++++
 tests/pyop2/test_api.py                   |   4 +-
 10 files changed, 270 insertions(+), 54 deletions(-)

diff --git a/firedrake/bcs.py b/firedrake/bcs.py
index f1568824bd..f0d007ede4 100644
--- a/firedrake/bcs.py
+++ b/firedrake/bcs.py
@@ -41,7 +41,7 @@ class BCBase(object):
 
     def __init__(self, V, sub_domain):
         self._function_space = V
-        self.sub_domain = sub_domain
+        self.sub_domain = (sub_domain, ) if isinstance(sub_domain, str) else as_tuple(sub_domain)
         # If this BC is defined on a subspace (IndexedFunctionSpace or
         # ComponentFunctionSpace, possibly recursively), pull out the appropriate
         # indices.
@@ -289,11 +289,9 @@ def __init__(self, V, g, sub_domain, method=None):
             warnings.simplefilter('always', DeprecationWarning)
             warnings.warn("Selecting a bcs method is deprecated. Only topological association is supported",
                           DeprecationWarning)
-        if len(V.boundary_set):
-            subs = [sub_domain] if type(sub_domain) in {int, str} else sub_domain
-            if any(sub not in V.boundary_set for sub in subs):
-                raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
         super().__init__(V, sub_domain)
+        if len(V.boundary_set) and not set(self.sub_domain).issubset(V.boundary_set):
+            raise ValueError(f"Sub-domain {self.sub_domain} not in the boundary set of the restricted space {V.boundary_set}.")
         if len(V) > 1:
             raise ValueError("Cannot apply boundary conditions on mixed spaces directly.\n"
                              "Apply to the components by indexing the space with .sub(...)")
diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx
index de0c63986e..eda8470d4e 100644
--- a/firedrake/cython/dmcommon.pyx
+++ b/firedrake/cython/dmcommon.pyx
@@ -1205,7 +1205,7 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
         PETSc.DM dm
         PETSc.Section section
         PETSc.IS renumbering
-        PetscInt i, p, layers, pStart, pEnd, dof, j
+        PetscInt i, p, layers, offset_top, pStart, pEnd, dof, j, k
         PetscInt dimension, ndof
         PetscInt *dof_array = NULL
         const PetscInt *entity_point_map
@@ -1213,6 +1213,10 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
         np.ndarray layer_extents
         np.ndarray points
         bint variable, extruded, on_base_
+        PETSc.SF point_sf
+        PetscInt nleaves
+        const PetscInt *ilocal = NULL
+        PetscInt factor
 
     dm = mesh.topology_dm
     if isinstance(dm, PETSc.DMSwarm) and on_base:
@@ -1221,32 +1225,31 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
     extruded = mesh.cell_set._extruded
     extruded_periodic = mesh.cell_set._extruded_periodic
     on_base_ = on_base
+    dimension = get_topological_dimension(dm)
     nodes_per_entity = np.asarray(nodes_per_entity, dtype=IntType)
     if variable:
         layer_extents = mesh.layer_extents
+        nodes = nodes_per_entity.reshape(dimension + 1, -1)
     elif extruded:
         if on_base:
-            nodes_per_entity = sum(nodes_per_entity[:, i] for i in range(2))
+            nodes = sum(nodes_per_entity[:, i] for i in range(2)).reshape(dimension + 1, -1)
         else:
             if extruded_periodic:
-                nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - 1) for i in range(2))
+                nodes = sum(nodes_per_entity[:, i]*(mesh.layers - 1) for i in range(2)).reshape(dimension + 1, -1)
             else:
-                nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2))
+                nodes = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2)).reshape(dimension + 1, -1)
+    else:
+        nodes = nodes_per_entity.reshape(dimension + 1, -1)
     section = PETSc.Section().create(comm=mesh._comm)
     get_chart(dm.dm, &pStart, &pEnd)
     section.setChart(pStart, pEnd)
-    if boundary_set:
-        renumbering, (constrainedStart, constrainedEnd) = plex_renumbering(dm,
-            mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set)
+    if boundary_set and not extruded:
+        renumbering = plex_renumbering(dm, mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set)
     else:
         renumbering = mesh._dm_renumbering
-        constrainedStart = -1
-        constrainedEnd = -1
     CHKERR(PetscSectionSetPermutation(section.sec, renumbering.iset))
-    dimension = get_topological_dimension(dm)
-    nodes = nodes_per_entity.reshape(dimension + 1, -1)
     for i in range(dimension + 1):
         get_depth_stratum(dm.dm, i, &pStart, &pEnd)  # gets all points at dim i
         if not variable:
@@ -1260,9 +1263,27 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
                     ndof = layers*nodes[i, 0] + (layers - 1)*nodes[i, 1]
             CHKERR(PetscSectionSetDof(section.sec, p, block_size * ndof))
 
+    if boundary_set and extruded and variable:
+        raise NotImplementedError("Not implemented for variable layer extrusion")
     if boundary_set:
+        # Handle "bottom" and "top" first.
+        if "bottom" in boundary_set and "top" in boundary_set:
+            factor = 2
+        elif "bottom" in boundary_set or "top" in boundary_set:
+            factor = 1
+        else:
+            factor = 0
+        if factor > 0:
+            for i in range(dimension + 1):
+                get_depth_stratum(dm.dm, i, &pStart, &pEnd)
+                dof = nodes_per_entity[i, 0]
+                for p in range(pStart, pEnd):
+                    CHKERR(PetscSectionSetConstraintDof(section.sec, p, factor * dof))
+        # Potentially overwrite ds_t and dS_t constrained DoFs set in the {"bottom", "top"} cases.
         for marker in boundary_set:
-            if marker == "on_boundary":
+            if marker in ["bottom", "top"]:
+                continue
+            elif marker == "on_boundary":
                 label = "exterior_facets"
                 marker = 1
             else:
@@ -1276,11 +1297,36 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
                 CHKERR(PetscSectionGetDof(section.sec, p, &dof))
                 CHKERR(PetscSectionSetConstraintDof(section.sec, p, dof))
     section.setUp()
-    if boundary_set:  # have to loop again as we need to call section.setUp() first
+    if boundary_set:
+        # have to loop again as we need to call section.setUp() first
+        CHKERR(PetscSectionGetMaxDof(section.sec, &dof))
+        CHKERR(PetscMalloc1(dof, &dof_array))
+        for i in range(dof):
+            dof_array[i] = -1
+        if "bottom" in boundary_set or "top" in boundary_set:
+            for i in range(dimension + 1):
+                get_depth_stratum(dm.dm, i, &pStart, &pEnd)
+                if pEnd == pStart:
+                    continue
+                dof = nodes_per_entity[i, 0]
+                j = 0
+                if "bottom" in boundary_set:
+                    for k in range(dof):
+                        dof_array[j] = k
+                        j += 1
+                if "top" in boundary_set:
+                    offset_top = (nodes_per_entity[i, 0] + nodes_per_entity[i, 1]) * (mesh.layers - 1)
+                    for k in range(dof):
+                        dof_array[j] = offset_top + k
+                        j += 1
+                for p in range(pStart, pEnd):
+                    # Potentially set wrong values for ds_t and dS_t constrained DoFs here,
+                    # but we will overwrite them below.
+                    CHKERR(PetscSectionSetConstraintIndices(section.sec, p, dof_array))
         for marker in boundary_set:
-            if marker == "on_boundary":
+            if marker in ["bottom", "top"]:
+                continue
+            elif marker == "on_boundary":
                 label = "exterior_facets"
                 marker = 1
             else:
@@ -1289,24 +1335,24 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
             if n == 0:
                 continue
             points = dm.getStratumIS(label, marker).indices
-            CHKERR(PetscSectionGetMaxDof(section.sec, &dof))
-            CHKERR(PetscMalloc1(dof, &dof_array))
             for i in range(n):
                 p = points[i]
                 CHKERR(PetscSectionGetDof(section.sec, p, &dof))
                 for j in range(dof):
                     dof_array[j] = j
                 CHKERR(PetscSectionSetConstraintIndices(section.sec, p, dof_array))
-            CHKERR(PetscFree(dof_array))
-
+        CHKERR(PetscFree(dof_array))
 
     constrained_nodes = 0
-
-    CHKERR(ISGetIndices(renumbering.iset, &entity_point_map))
-    for entity in range(constrainedStart, constrainedEnd):
-        CHKERR(PetscSectionGetDof(section.sec, entity_point_map[entity], &dof))
+    get_chart(dm.dm, &pStart, &pEnd)
+    point_sf = dm.getPointSF()
+    CHKERR(PetscSFGetGraph(point_sf.sf, NULL, &nleaves, &ilocal, NULL))
+    for p in range(pStart, pEnd):
+        CHKERR(PetscSectionGetConstraintDof(section.sec, p, &dof))
         constrained_nodes += dof
-    CHKERR(ISRestoreIndices(renumbering.iset, &entity_point_map))
-
+    for i in range(nleaves):
+        p = ilocal[i] if ilocal else i
+        CHKERR(PetscSectionGetConstraintDof(section.sec, p, &dof))
+        constrained_nodes -= dof
 
     return section, constrained_nodes
@@ -2460,7 +2506,7 @@ def plex_renumbering(PETSc.DM plex,
     perm_is.setType("general")
     CHKERR(ISGeneralSetIndices(perm_is.iset, pEnd - pStart,
                                perm, PETSC_OWN_POINTER))
-    return perm_is, (lidx[1], lidx[3])
+    return perm_is
 
 @cython.boundscheck(False)
 @cython.wraparound(False)
@@ -3310,23 +3356,31 @@ def make_global_numbering(PETSc.Section lsec, PETSc.Section gsec):
     :arg lsec: Section describing local dof layout and numbers.
     :arg gsec: Section describing global dof layout and numbers."""
     cdef:
-        PetscInt c, p, pStart, pEnd, dof, cdof, loff, goff
+        PetscInt c, cc, p, pStart, pEnd, dof, cdof, loff, goff
         np.ndarray val
+        PetscInt *dof_array = NULL
 
     val = np.empty(lsec.getStorageSize(), dtype=IntType)
     pStart, pEnd = lsec.getChart()
-
     for p in range(pStart, pEnd):
         CHKERR(PetscSectionGetDof(lsec.sec, p, &dof))
         CHKERR(PetscSectionGetConstraintDof(lsec.sec, p, &cdof))
         if dof > 0:
             CHKERR(PetscSectionGetOffset(lsec.sec, p, &loff))
             CHKERR(PetscSectionGetOffset(gsec.sec, p, &goff))
+            goff = cabs(goff)
             if cdof > 0:
+                CHKERR(PetscSectionGetConstraintIndices(lsec.sec, p, &dof_array))
+                for c in range(dof):
+                    val[loff + c] = -2
+                for c in range(cdof):
+                    val[loff + dof_array[c]] = -1
+                cc = 0
                 for c in range(dof):
-                    val[loff + c] = -1
+                    if val[loff + c] < -1:
+                        val[loff + c] = goff + cc
+                        cc += 1
             else:
-                goff = cabs(goff)
                 for c in range(dof):
                     val[loff + c] = goff + c
     return val
diff --git a/firedrake/cython/petschdr.pxi b/firedrake/cython/petschdr.pxi
index 55786e7184..9a0bff609d 100644
--- a/firedrake/cython/petschdr.pxi
+++ b/firedrake/cython/petschdr.pxi
@@ -97,6 +97,7 @@ cdef extern from "petscis.h" nogil:
     int PetscSectionGetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt*)
     int PetscSectionSetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt)
     int PetscSectionSetConstraintIndices(PETSc.PetscSection,PetscInt, PetscInt[])
+    int PetscSectionGetConstraintIndices(PETSc.PetscSection,PetscInt, const PetscInt**)
     int PetscSectionGetMaxDof(PETSc.PetscSection,PetscInt*)
     int PetscSectionSetPermutation(PETSc.PetscSection,PETSc.PetscIS)
     int ISGetIndices(PETSc.PetscIS,PetscInt*[])
diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py
index 65592f6659..8fc81244f7 100644
--- a/firedrake/functionspaceimpl.py
+++ b/firedrake/functionspaceimpl.py
@@ -14,6 +14,7 @@
 import finat.ufl
 
 from pyop2 import op2, mpi
+from pyop2.utils import as_tuple
 
 from firedrake import dmhooks, utils
 from firedrake.functionspacedata import get_shared_data, create_element
@@ -876,6 +877,16 @@ class RestrictedFunctionSpace(FunctionSpace):
     """
     def __init__(self, function_space, boundary_set=frozenset(), name=None):
         label = ""
+        boundary_set_ = []
+        for boundary_domain in boundary_set:
+            if isinstance(boundary_domain, str):
+                boundary_set_.append(boundary_domain)
+            else:
+                # Currently, cannot handle intersection of boundaries;
+                # e.g., boundary_set = [(1, 2)], which is different from [1, 2].
+                bd, = as_tuple(boundary_domain)
+                boundary_set_.append(bd)
+        boundary_set = boundary_set_
         for boundary_domain in boundary_set:
             label += str(boundary_domain)
             label += "_"
@@ -896,7 +907,8 @@ def set_shared_data(self):
         self.node_set = sdata.node_set
         r"""A :class:`pyop2.types.set.Set` representing the function space
         nodes."""
         self.dof_dset = op2.DataSet(self.node_set, self.shape or 1,
-                                    name="%s_nodes_dset" % self.name)
+                                    name="%s_nodes_dset" % self.name,
+                                    apply_local_global_filter=sdata.extruded)
         r"""A :class:`pyop2.types.dataset.DataSet` representing the function space
         degrees of freedom."""
diff --git a/firedrake/mesh.py b/firedrake/mesh.py
index 7c99b47972..2c0ef9e198 100644
--- a/firedrake/mesh.py
+++ b/firedrake/mesh.py
@@ -1245,7 +1245,7 @@ def _renumber_entities(self, reorder):
         else:
             # No reordering
             reordering = None
-        return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)[0]
+        return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)
 
     @utils.cached_property
     def cell_closure(self):
@@ -1979,7 +1979,7 @@ def _renumber_entities(self, reorder):
             perm_is.setIndices(perm)
             return perm_is
         else:
-            return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)[0]
+            return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)
 
     @utils.cached_property  # TODO: Recalculate if mesh moves
     def cell_closure(self):
diff --git a/firedrake/slate/static_condensation/hybridization.py b/firedrake/slate/static_condensation/hybridization.py
index 997f969684..ce0f184a07 100644
--- a/firedrake/slate/static_condensation/hybridization.py
+++ b/firedrake/slate/static_condensation/hybridization.py
@@ -1,5 +1,4 @@
 import functools
-import numbers
 
 import numpy as np
 import ufl
@@ -11,7 +10,6 @@
 from firedrake.petsc import PETSc
 from firedrake.parloops import par_loop, READ, INC
 from firedrake.slate.slate import Tensor, AssembledVector
-from pyop2.utils import as_tuple
 from firedrake.slate.static_condensation.la_utils import SchurComplementBuilder
 from firedrake.ufl_expr import adjoint
 
@@ -153,11 +151,7 @@ def initialize(self, pc):
                 if bc.function_space().index != self.vidx:
                     raise NotImplementedError("Dirichlet bc set on unsupported space.")
                 # append the set of sub domains
-                subdom = bc.sub_domain
-                if isinstance(subdom, str):
-                    neumann_subdomains |= set([subdom])
-                else:
-                    neumann_subdomains |= set(as_tuple(subdom, numbers.Integral))
+                neumann_subdomains |= set(bc.sub_domain)
 
             # separate out the top and bottom bcs
             extruded_neumann_subdomains = neumann_subdomains & {"top", "bottom"}
diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py
index f41ee3d5b9..2eb288dcf8 100644
--- a/pyop2/types/dat.py
+++ b/pyop2/types/dat.py
@@ -807,15 +807,37 @@ def _vec(self):
         # But use getSizes to save an Allreduce in computing the
         # global size.
         size = self.dataset.layout_vec.getSizes()
-        data = self._data[:size[0]]
+        if self.dataset._apply_local_global_filter:
+            data = self._data_filtered
+        else:
+            data = self._data[:size[0]]
         return PETSc.Vec().createWithArray(data, size=size, bsize=self.cdim,
                                            comm=self.comm)
 
+    @utils.cached_property
+    def _data_filtered(self):
+        size, _ = self.dataset.layout_vec.getSizes()
+        size //= self.dataset.layout_vec.block_size
+        data = self._data[:size]
+        return np.empty_like(data)
+
+    @utils.cached_property
+    def _data_filter(self):
+        lgmap = self.dataset.lgmap
+        n = self.dataset.size
+        lgmap_owned = lgmap.block_indices[:n]
+        return lgmap_owned >= 0
+
     @contextlib.contextmanager
     def vec_context(self, access):
         r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`.
 
         :param access: Access descriptor: READ, WRITE, or RW."""
+        size = self.dataset.size
+        if self.dataset._apply_local_global_filter and access is not Access.WRITE:
+            self._data_filtered[:] = self._data[:size][self._data_filter]
         yield self._vec
+        if self.dataset._apply_local_global_filter and access is not Access.READ:
+            self._data[:size][self._data_filter] = self._data_filtered[:]
         if access is not Access.READ:
             self.halo_valid = False
diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py
index 3b4f4bfd8a..087d9b091d 100644
--- a/pyop2/types/dataset.py
+++ b/pyop2/types/dataset.py
@@ -21,8 +21,9 @@ class DataSet(caching.ObjectCached):
 
     @utils.validate_type(('iter_set', Set, ex.SetTypeError),
                          ('dim', (numbers.Integral, tuple, list), ex.DimTypeError),
-                         ('name', str, ex.NameTypeError))
-    def __init__(self, iter_set, dim=1, name=None):
+                         ('name', str, ex.NameTypeError),
+                         ('apply_local_global_filter', bool, ex.DataTypeError))
+    def __init__(self, iter_set, dim=1, name=None, apply_local_global_filter=False):
         if isinstance(iter_set, ExtrudedSet):
             raise NotImplementedError("Not allowed!")
         if self._initialized:
@@ -35,18 +36,19 @@ def __init__(self, iter_set, dim=1, name=None):
         self._cdim = np.prod(self._dim).item()
         self._name = name or "dset_#x%x" % id(self)
         self._initialized = True
+        self._apply_local_global_filter = apply_local_global_filter
 
     @classmethod
     def _process_args(cls, *args, **kwargs):
         return (args[0], ) + args, kwargs
 
     @classmethod
-    def _cache_key(cls, iter_set, dim=1, name=None):
+    def _cache_key(cls, iter_set, dim=1, name=None, apply_local_global_filter=False):
         return (iter_set, utils.as_tuple(dim, numbers.Integral))
 
     @utils.cached_property
     def _wrapper_cache_key_(self):
-        return (type(self), self.dim, self._set._wrapper_cache_key_)
+        return (type(self), self.dim, self._set._wrapper_cache_key_, self._apply_local_global_filter)
 
     def __getstate__(self):
         """Extract state to pickle."""
@@ -97,11 +99,11 @@ def __len__(self):
         return 1
 
     def __str__(self):
-        return "OP2 DataSet: %s on set %s, with dim %s" % \
-            (self._name, self._set, self._dim)
+        return "OP2 DataSet: %s on set %s, with dim %s, %s" % \
+            (self._name, self._set, self._dim, self._apply_local_global_filter)
 
     def __repr__(self):
-        return "DataSet(%r, %r, %r)" % (self._set, self._dim, self._name)
+        return "DataSet(%r, %r, %r, %r)" % (self._set, self._dim, self._name, self._apply_local_global_filter)
 
     def __contains__(self, dat):
         """Indicate whether a given Dat is compatible with this DataSet."""
@@ -501,6 +503,8 @@ def lgmap(self):
                 tmp_indices = np.searchsorted(current_offsets, l2g, side="right") - 1
                 idx[:] = l2g[:] - current_offsets[tmp_indices] + \
                     all_field_offsets[tmp_indices] + all_local_offsets[tmp_indices]
+                # Explicitly set -1 for constrained DoFs.
+                idx[l2g < 0] = -1
             self.comm.Allgather(owned_sz, current_offsets[1:])
             all_local_offsets += current_offsets[1:]
             start += s.total_size * s.cdim
diff --git a/tests/firedrake/regression/test_restricted_function_space.py b/tests/firedrake/regression/test_restricted_function_space.py
index 51e2139ae6..dc9a2ecc64 100644
--- a/tests/firedrake/regression/test_restricted_function_space.py
+++ b/tests/firedrake/regression/test_restricted_function_space.py
@@ -228,3 +228,134 @@ def test_poisson_mixed_restricted_spaces(i, j):
 
     assert errornorm(w.subfunctions[0], w2.subfunctions[0]) < 1.e-12
     assert errornorm(w.subfunctions[1], w2.subfunctions[1]) < 1.e-12
+
+
+@pytest.mark.parallel(nprocs=2)
+def test_restricted_function_space_extrusion_basics():
+    #
+    #                rank 0                        rank 1
+    #
+    # plex points:
+    #
+    #     +-------+-------+               +-------+-------+
+    #     |       |       |               |       |       |
+    #     |       |       |               |       |       |
+    #     |       |       |               |       |       |
+    #     +-------+-------+               +-------+-------+
+    #     2   0  (3) (1) (4)             (4) (1)  2   0   3      () = ghost
+    #
+    # mesh._dm_renumbering:
+    #
+    #     [0, 2, 3, 1, 4]                 [0, 3, 2, 1, 4]
+    #
+    # Local DoFs:
+    #
+    #     5---2--(8)(11)(14)             (14)(11)--8---2---5
+    #     |       |       |               |       |       |
+    #     4   1  (7)(10)(13)             (13)(10)  7   1   4
+    #     |       |       |               |       |       |
+    #     3---0--(6)-(9)(12)             (12)-(9)--6---0---3     () = ghost
+    #
+    # Global DoFs:
+    #
+    #     3---1---9---5---7
+    #     |       |       |
+    #     2   0   8   4   6
+    #     |       |       |
+    #     x---x---x---x---x
+    #
+    # LGMap:
+    #
+    #     rank 0 : [-1, 0, 1, -1, 2, 3, -1, 8, 9, -1, 4, 5, -1, 6, 7]
+    #     rank 1 : [-1, 4, 5, -1, 6, 7, -1, 8, 9, -1, 0, 1, -1, 2, 3]
+    mesh = UnitIntervalMesh(2)
+    extm = ExtrudedMesh(mesh, 1)
+    V = FunctionSpace(extm, "CG", 2)
+    V_res = RestrictedFunctionSpace(V, boundary_set=["bottom"])
+    # Check lgmap.
+    lgmap = V_res.topological.local_to_global_map(None)
+    if mesh.comm.rank == 0:
+        lgmap_expected = [-1, 0, 1, -1, 2, 3, -1, 8, 9, -1, 4, 5, -1, 6, 7]
+    else:
+        lgmap_expected = [-1, 4, 5, -1, 6, 7, -1, 8, 9, -1, 0, 1, -1, 2, 3]
+    assert np.allclose(lgmap.indices, lgmap_expected)
+    # Check vec.
+    n = V_res.dof_dset.size
+    lgmap_owned = lgmap.indices[:n]
+    local_global_filter = lgmap_owned >= 0
+    local_array = 1.0 * np.arange(V_res.dof_dset.total_size)
+    f = Function(V_res)
+    f.dat.data_wo_with_halos[:] = local_array
+    with f.dat.vec as v:
+        assert np.allclose(v.getArray(), local_array[:n][local_global_filter])
+        v *= 2.
+    assert np.allclose(f.dat.data_ro_with_halos[:n][local_global_filter], 2. * local_array[:n][local_global_filter])
+    # Solve Poisson problem.
+    x, y = SpatialCoordinate(extm)
+    normal = FacetNormal(extm)
+    exact = Function(V_res).interpolate(x**2 * y**2)
+    exact_grad = as_vector([2 * x * y**2, 2 * x**2 * y])
+    u = TrialFunction(V_res)
+    v = TestFunction(V_res)
+    a = inner(grad(u), grad(v)) * dx
+    L = inner(-2 * (x**2 + y**2), v) * dx + inner(dot(exact_grad, normal), v) * ds_v(2) + inner(dot(exact_grad, normal), v) * ds_t
+    bc = DirichletBC(V_res, exact, "bottom")
+    sol = Function(V_res)
+    solve(a == L, sol, bcs=[bc])
+    assert assemble(inner(sol - exact, sol - exact) * dx)**0.5 < 1.e-15
+
+
+@pytest.mark.parallel(nprocs=4)
+@pytest.mark.parametrize("ncells", [2, 4])
+def test_restricted_function_space_extrusion_poisson(ncells):
+    mesh = UnitIntervalMesh(ncells)
+    extm = ExtrudedMesh(mesh, ncells)
+    subdomain_ids = ["bottom", "top", 1, 2]
+    V = FunctionSpace(extm, "CG", 4)
+    V_res = RestrictedFunctionSpace(V, boundary_set=subdomain_ids)
+    x, y = SpatialCoordinate(extm)
+    exact = Function(V_res).interpolate(x**2 * y**2)
+    u = TrialFunction(V_res)
+    v = TestFunction(V_res)
+    a = inner(grad(u), grad(v)) * dx
+    L = inner(-2 * (x**2 + y**2), v) * dx
+    bc = DirichletBC(V_res, exact, subdomain_ids)
+    sol = Function(V_res)
+    solve(a == L, sol, bcs=[bc])
+    assert assemble(inner(sol - exact, sol - exact) * dx)**0.5 < 1.e-15
+
+
+@pytest.mark.parallel(nprocs=4)
+@pytest.mark.parametrize("ncells", [2, 16])
+def test_restricted_function_space_extrusion_stokes(ncells):
+    mesh = UnitIntervalMesh(ncells)
+    extm = ExtrudedMesh(mesh, ncells)
+    subdomain_ids = [1, 2, "bottom"]
+    f_value_0 = as_vector([1., 1.])
+    bc_value_0 = as_vector([0., 0.])
+    # Solve reference problem.
+    V = VectorFunctionSpace(extm, "CG", 2)
+    Q = FunctionSpace(extm, "CG", 1)
+    W = V * Q
+    u, p = TrialFunctions(W)
+    v, q = TestFunctions(W)
+    a = inner(2 * sym(grad(u)), grad(v)) * dx - inner(p, div(v)) * dx + inner(div(u), q) * dx
+    L = inner(f_value_0, v) * dx
+    bc = DirichletBC(W.sub(0), bc_value_0, subdomain_ids)
+    sol = Function(W)
+    solve(a == L, sol, bcs=[bc])
+    # Solve problem on restricted space.
+    V_res = RestrictedFunctionSpace(V, boundary_set=subdomain_ids)
+    W_res = V_res * Q
+    u_res, p = TrialFunctions(W_res)
+    v_res, q = TestFunctions(W_res)
+    a_res = inner(2 * sym(grad(u_res)), grad(v_res)) * dx - inner(p, div(v_res)) * dx + inner(div(u_res), q) * dx
+    L_res = inner(f_value_0, v_res) * dx
+    bc_res = DirichletBC(W_res.sub(0), bc_value_0, subdomain_ids)
+    sol_res = Function(W_res)
+    solve(a_res == L_res, sol_res, bcs=[bc_res])
+    # Compare.
+    assert assemble(inner(sol_res - sol, sol_res - sol) * dx)**0.5 < 1.e-15
+    # -- Actually, the ordering is the same.
+    assert np.allclose(sol_res.subfunctions[0].dat.data_ro_with_halos, sol.subfunctions[0].dat.data_ro_with_halos)
+    assert np.allclose(sol_res.subfunctions[1].dat.data_ro_with_halos, sol.subfunctions[1].dat.data_ro_with_halos)
diff --git a/tests/pyop2/test_api.py b/tests/pyop2/test_api.py
index 468d175587..6fe24460c4 100644
--- a/tests/pyop2/test_api.py
+++ b/tests/pyop2/test_api.py
@@ -504,8 +504,8 @@ def test_dset_repr(self, dset):
 
     def test_dset_str(self, dset):
         "DataSet should have the expected string representation."
-        assert str(dset) == "OP2 DataSet: %s on set %s, with dim %s" \
-            % (dset.name, dset.set, dset.dim)
+        assert str(dset) == "OP2 DataSet: %s on set %s, with dim %s, %s" \
+            % (dset.name, dset.set, dset.dim, dset._apply_local_global_filter)
 
     def test_dset_eq(self, dset):
         "The equality test for DataSets is same dim and same set"
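
A minimal usage sketch of the capability this patch enables (a RestrictedFunctionSpace on an extruded mesh), following the pattern of the new tests above. The mesh sizes, polynomial degree, boundary set and right-hand side below are illustrative assumptions, not values taken from the patch:

    from firedrake import *

    # Build a small extruded mesh and restrict a CG space at its "bottom" and "top" boundaries.
    base = UnitIntervalMesh(4)
    mesh = ExtrudedMesh(base, 4)
    V = FunctionSpace(mesh, "CG", 2)
    V_res = RestrictedFunctionSpace(V, boundary_set=["bottom", "top"])

    # Dirichlet data on a restricted space must use subdomains contained in its boundary_set.
    bc = DirichletBC(V_res, 0, ["bottom", "top"])

    u = TrialFunction(V_res)
    v = TestFunction(V_res)
    a = inner(grad(u), grad(v)) * dx
    L = inner(Constant(1.0), v) * dx

    uh = Function(V_res)
    solve(a == L, uh, bcs=[bc])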