Skip to content

Commit

Permalink
allow for Petrov-Galerkin formulations
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Dec 11, 2024
1 parent 17c4106 commit f641b25
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 46 deletions.
83 changes: 37 additions & 46 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,25 +1180,24 @@ def allocate(self):

def _apply_bc(self, tensor, bc):
# TODO Maybe this could be a singledispatchmethod?
if isinstance(bc, DirichletBC):
self._apply_dirichlet_bc(tensor, bc)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if not self._zero_bc_nodes:
if self._diagonal:
assert isinstance(bc, DirichletBC)
assert not self._zero_bc_nodes
tensor_func = tensor.riesz_representation(riesz_map="l2")
if self._diagonal:
bc.set(tensor_func, 1)
else:
bc.apply(tensor_func)
bc.set(tensor_func, 1)
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
else:
bc.zero(tensor)
test, = self._form.arguments()
if test.function_space() == bc.function_space_parent:
if isinstance(bc, DirichletBC):
assert self._zero_bc_nodes
bc.zero(tensor)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
else:
raise AssertionError

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space():
Expand Down Expand Up @@ -1421,31 +1420,29 @@ def _apply_bc(self, tensor, bc):
index = 0 if V.index is None else V.index
space = V if V.parent is None else V.parent
if isinstance(bc, DirichletBC):
if space != spaces[0]:
raise TypeError("bc space does not match the test function space")
elif space != spaces[1]:
raise TypeError("bc space does not match the trial function space")

# Set diagonal entries on bc nodes to 1 if the current
# block is on the matrix diagonal and its index matches the
# index of the function space the bc is defined on.
op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight)

if space == spaces[0] and space == spaces[1]:
# Set diagonal entries on bc nodes to 1 if the current
# block is on the matrix diagonal and its index matches the
# index of the function space the bc is defined on.
op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight)
# Handle off-diagonal block involving real function space.
# "lgmaps" is correctly constructed in _matrix_arg, but
# is ignored by PyOP2 in this case.
# Walk through row blocks associated with index.
for j, s in enumerate(space):
if j != index and s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set)
if space == spaces[0]:
for j, s in enumerate(spaces[1]):
if j != index and s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set)
# Walk through col blocks associated with index.
for i, s in enumerate(space):
if i != index and s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set)
if space == spaces[1]:
for i, s in enumerate(spaces[0]):
if i != index and s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set)
elif isinstance(bc, EquationBCSplit):
for j, s in enumerate(spaces[1]):
if s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set)
if space == spaces[0]:
for j, s in enumerate(spaces[1]):
if s.ufl_element().family() == "Real":
self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False).assemble(tensor=tensor)
else:
raise AssertionError
Expand Down Expand Up @@ -1889,19 +1886,13 @@ def get_indicess(self):

def _filter_bcs(self, row, col):
assert len(self._form.arguments()) == 2 and not self._diagonal
bcrow = [bc for bc in self._bcs if bc.function_space_parent == self.test_function_space]
bccol = [bc for bc in self._bcs if bc.function_space_parent == self.trial_function_space and isinstance(bc, DirichletBC)]
if len(self.test_function_space) > 1:
bcrow = tuple(bc for bc in self._bcs
if bc.function_space_index() == row)
else:
bcrow = self._bcs

bcrow = [bc for bc in bcrow if bc.function_space_index() == row]
if len(self.trial_function_space) > 1:
bccol = tuple(bc for bc in self._bcs
if bc.function_space_index() == col
and isinstance(bc, DirichletBC))
else:
bccol = tuple(bc for bc in self._bcs if isinstance(bc, DirichletBC))
return bcrow, bccol
bccol = [bc for bc in bccol if bc.function_space_index() == col]
return tuple(bcrow), tuple(bccol)

def needs_unrolling(self):
"""Do we need to address matrix elements directly rather than in
Expand Down
1 change: 1 addition & 0 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, V, sub_domain):
else:
# All done
break
self.function_space_parent = fs
# Used for indexing functions passed in.
self._indices = tuple(reversed(indices))
# init bcs
Expand Down

0 comments on commit f641b25

Please sign in to comment.