Skip to content

Commit

Permalink
Correct broken LGMap
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Mar 1, 2024
1 parent 49390c7 commit 9fc3d85
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions firedrake/preconditioners/fdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,31 +230,40 @@ def allocate_matrix(self, Amat, V, J, bcs, fcp, pmat_type, use_static_condensati
self.fises = PETSc.IS().createBlock(Vbig.value_size, fdofs, comm=PETSc.COMM_SELF)

# Create data structures needed for assembly
def mask_local_indices(indices):
local_indices = numpy.arange(len(indices), dtype=PETSc.IntType)
local_indices[indices < 0] = -1
return local_indices

def get_lgmap(V, broken=False):
lgmap = V.dof_dset.lgmap
if broken:
mesh = V.mesh()
ncell = mesh.cell_set.size
indices = V.cell_node_list[:ncell]
indices = V.cell_node_list[:ncell].copy()
if V.extruded:
indices = extrude_cell_node_list(indices, V.offset, mesh.layers)
lgmap.apply(indices, result=indices)
return PETSc.LGMap().create(indices, bsize=lgmap.getBlockSize(), comm=lgmap.getComm())
else:
return V.local_to_global_map([], lgmap=lgmap, non_ghost_cells=True)

self.non_ghosted_lgmaps = {Vsub: get_lgmap(Vsub) for Vsub in V}
self.lgmaps = {Vsub: Vsub.local_to_global_map([bc for bc in bcs if bc.function_space() == Vsub]) for Vsub in V}
self.indices = {Vsub: op2.Dat(Vsub.dof_dset,
mask_local_indices(self.lgmaps[Vsub].indices),
dtype=PETSc.IntType) for Vsub in V}
def mask_local_indices(V, lgmap, broken=False):
if broken:
indices = V.cell_node_list.copy()
if V.extruded:
indices = extrude_cell_node_list(indices, V.offset, V.mesh().layers)
mask = indices.flatten()
lgmap.apply(mask, result=mask)
V = FunctionSpace(V.mesh(), finat.ufl.BrokenElement(V.ufl_element()))
else:
mask = lgmap.indices

indices = numpy.arange(len(mask), dtype=PETSc.IntType)
indices[mask == -1] = -1
indices_dat = op2.Dat(V.dof_dset, indices, dtype=PETSc.IntType)
indices_acc = indices_dat(op2.READ, V.cell_node_map())
return indices_acc

#self.indices = {Vsub: op2.Dat(Vsub.dof_dset, numpy.arange(Vsub.dof_dset.set.size, dtype=PETSc.IntType), dtype=PETSc.IntType) for Vsub in V}
broken = pmat_type == "is"
self.non_ghosted_lgmaps = {Vsub: get_lgmap(Vsub, broken) for Vsub in V}
self.lgmaps = {Vsub: Vsub.local_to_global_map([bc for bc in bcs if bc.function_space() == Vsub]) for Vsub in V}
self.indices = {Vsub: mask_local_indices(Vsub, self.lgmaps[Vsub], broken) for Vsub in V}
self.coefficients, assembly_callables = self.assemble_coefficients(J, fcp)
self.assemblers = {}
self.kernels = []
Expand Down Expand Up @@ -319,7 +328,7 @@ def get_lgmap(V, broken=False):
P.setPreallocationNNZ((dnz, onz))

P.setOption(PETSc.Mat.Option.NEW_NONZERO_ALLOCATION_ERR, True)
#P.setOption(PETSc.Mat.Option.UNUSED_NONZERO_LOCATION_ERR, ptype != "is")
P.setOption(PETSc.Mat.Option.UNUSED_NONZERO_LOCATION_ERR, ptype != "is")
P.setOption(PETSc.Mat.Option.STRUCTURALLY_SYMMETRIC, on_diag)
P.setOption(PETSc.Mat.Option.KEEP_NONZERO_PATTERN, True)
if ptype.endswith("sbaij"):
Expand All @@ -329,7 +338,7 @@ def get_lgmap(V, broken=False):
self.set_values(P, Vrow, Vcol, addv, mat_type="preallocator")
# populate diagonal entries
if on_diag:
n = len(self.indices[Vrow].data_ro_with_halos)
n = len(self.non_ghosted_lgmaps[Vrow].indices)
i = numpy.arange(n, dtype=PETSc.IntType).reshape(-1, 1)
v = numpy.ones(i.shape, dtype=PETSc.ScalarType)
P.setValuesLocalRCV(i, i, v, addv=addv)
Expand All @@ -339,9 +348,12 @@ def get_lgmap(V, broken=False):
assembly_callables.append(P.zeroEntries)
assembly_callables.append(partial(self.set_values, P, Vrow, Vcol, addv, mat_type=ptype))
if on_diag:
bdofs = numpy.flatnonzero(self.indices[Vrow].data_ro < 0).astype(PETSc.IntType)[:, None]
own = Vrow.dof_dset.layout_vec.getLocalSize()
bdofs = numpy.flatnonzero(self.lgmaps[Vrow].indices[:own] < 0).astype(PETSc.IntType)[:, None]
Vrow.dof_dset.lgmap.apply(bdofs, result=bdofs)
assembly_callables.append(P.assemble)
assembly_callables.append(partial(P.zeroRowsLocal, bdofs, 1.0))
assembly_callables.append(partial(P.zeroRows, bdofs, 1.0))
# assembly_callables.append(P.view)

gamma = self.coefficients.get("facet")
if gamma is not None and gamma.function_space() == Vrow.dual():
Expand Down Expand Up @@ -757,7 +769,7 @@ def set_values(self, A, Vrow, Vcol, addv, mat_type="aij"):
TripleProductKernel(R0, M, C0))
self.kernels.append(element_kernel)
spaces = (Vrow, Vcol)[on_diag:]
indices_acc = tuple(self.indices[V](op2.READ, V.cell_node_map()) for V in spaces)
indices_acc = tuple(self.indices[V] for V in spaces)
coefficients = self.coefficients["cell"]
coefficients_acc = coefficients.dat(op2.READ, coefficients.cell_node_map())
kernel = element_kernel.kernel(on_diag=on_diag, addv=addv)
Expand Down

0 comments on commit 9fc3d85

Please sign in to comment.