diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f451b3f596..feb43f29bb 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1180,25 +1180,24 @@ def allocate(self): def _apply_bc(self, tensor, bc): # TODO Maybe this could be a singledispatchmethod? - if isinstance(bc, DirichletBC): - self._apply_dirichlet_bc(tensor, bc) - elif isinstance(bc, EquationBCSplit): - bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) - else: - raise AssertionError - - def _apply_dirichlet_bc(self, tensor, bc): - if not self._zero_bc_nodes: + if self._diagonal: + assert isinstance(bc, DirichletBC) + assert not self._zero_bc_nodes tensor_func = tensor.riesz_representation(riesz_map="l2") - if self._diagonal: - bc.set(tensor_func, 1) - else: - bc.apply(tensor_func) + bc.set(tensor_func, 1) tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) else: - bc.zero(tensor) + test, = self._form.arguments() + if test.function_space() == bc.function_space_parent: + if isinstance(bc, DirichletBC): + assert self._zero_bc_nodes + bc.zero(tensor) + elif isinstance(bc, EquationBCSplit): + bc.zero(tensor) + type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + else: + raise AssertionError def _check_tensor(self, tensor): if tensor.function_space() != self._form.arguments()[0].function_space(): @@ -1421,31 +1420,29 @@ def _apply_bc(self, tensor, bc): index = 0 if V.index is None else V.index space = V if V.parent is None else V.parent if isinstance(bc, DirichletBC): - if space != spaces[0]: - raise TypeError("bc space does not match the test function space") - elif space != spaces[1]: - raise TypeError("bc space does not match the trial function space") - - # Set diagonal entries on bc nodes to 1 if the current - # block is on the matrix diagonal and its index matches the - # index of the function space the bc is defined on. - op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight) - + if space == spaces[0] and space == spaces[1]: + # Set diagonal entries on bc nodes to 1 if the current + # block is on the matrix diagonal and its index matches the + # index of the function space the bc is defined on. + op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight) # Handle off-diagonal block involving real function space. # "lgmaps" is correctly constructed in _matrix_arg, but # is ignored by PyOP2 in this case. # Walk through row blocks associated with index. - for j, s in enumerate(space): - if j != index and s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) + if space == spaces[0]: + for j, s in enumerate(spaces[1]): + if j != index and s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) # Walk through col blocks associated with index. - for i, s in enumerate(space): - if i != index and s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set) + if space == spaces[1]: + for i, s in enumerate(spaces[0]): + if i != index and s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set) elif isinstance(bc, EquationBCSplit): - for j, s in enumerate(spaces[1]): - if s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) + if space == spaces[0]: + for j, s in enumerate(spaces[1]): + if s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False).assemble(tensor=tensor) else: raise AssertionError @@ -1889,19 +1886,13 @@ def get_indicess(self): def _filter_bcs(self, row, col): assert len(self._form.arguments()) == 2 and not self._diagonal + bcrow = [bc for bc in self._bcs if bc.function_space_parent == self.test_function_space] + bccol = [bc for bc in self._bcs if bc.function_space_parent == self.trial_function_space and isinstance(bc, DirichletBC)] if len(self.test_function_space) > 1: - bcrow = tuple(bc for bc in self._bcs - if bc.function_space_index() == row) - else: - bcrow = self._bcs - + bcrow = [bc for bc in bcrow if bc.function_space_index() == row] if len(self.trial_function_space) > 1: - bccol = tuple(bc for bc in self._bcs - if bc.function_space_index() == col - and isinstance(bc, DirichletBC)) - else: - bccol = tuple(bc for bc in self._bcs if isinstance(bc, DirichletBC)) - return bcrow, bccol + bccol = [bc for bc in bccol if bc.function_space_index() == col] + return tuple(bcrow), tuple(bccol) def needs_unrolling(self): """Do we need to address matrix elements directly rather than in diff --git a/firedrake/bcs.py b/firedrake/bcs.py index f0d007ede4..ae1f0b655c 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -63,6 +63,7 @@ def __init__(self, V, sub_domain): else: # All done break + self.function_space_parent = fs # Used for indexing functions passed in. self._indices = tuple(reversed(indices)) # init bcs