Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #674 from OP2/ksagiyam/composed_map
Browse files Browse the repository at this point in the history
add ComposedMap
  • Loading branch information
ksagiyam authored Oct 18, 2022
2 parents a0259a9 + fcf4250 commit 1ba576e
Show file tree
Hide file tree
Showing 8 changed files with 369 additions and 23 deletions.
29 changes: 27 additions & 2 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy
from loopy.types import OpaqueType
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg)
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg)
from pyop2.codegen.representation import (Accumulate, Argument, Comparison,
DummyInstruction, Extent, FixedIndex,
FunctionCall, Index, Indexed,
Expand Down Expand Up @@ -154,6 +154,28 @@ def indexed_vector(self, n, shape, layer=None):
return super().indexed_vector(n, shape, layer=layer, permute=permute)


class CMap(Map):

def __init__(self, *maps_):
# Copy over properties
self.variable = maps_[0].variable
self.unroll = maps_[0].unroll
self.layer_bounds = maps_[0].layer_bounds
self.interior_horizontal = maps_[0].interior_horizontal
self.prefetch = {}
self.values = maps_[0].values
self.offset = maps_[0].offset
self.maps_ = maps_

def indexed(self, multiindex, layer=None):
n, i, f = multiindex
n_ = n
for map_ in reversed(self.maps_):
if map_ is not self.maps_[0]:
n_, (_, _) = map_.indexed(MultiIndex(n_, FixedIndex(0), Index()), layer=None)
return self.maps_[0].indexed(MultiIndex(n_, i, f), layer=layer)


class Pack(metaclass=ABCMeta):

def pick_loop_indices(self, loop_index, layer_index=None, entity_index=None):
Expand Down Expand Up @@ -835,6 +857,8 @@ def _add_map(self, map_, unroll=False):
if isinstance(map_, PermutedMapKernelArg):
imap = self._add_map(map_.base_map, unroll)
map_ = PMap(imap, numpy.asarray(map_.permutation, dtype=IntType))
elif isinstance(map_, ComposedMapKernelArg):
map_ = CMap(*(self._add_map(m, unroll) for m in map_.base_maps))
else:
map_ = Map(interior_horizontal,
(self.bottom_layer, self.top_layer),
Expand Down Expand Up @@ -878,7 +902,8 @@ def wrapper_args(self):
# But we don't need to emit stuff for PMaps because they
# are a Map (already seen + a permutation [encoded in the
# indexing]).
if not isinstance(map_, PMap):
# CMaps do not have their own arguments, either.
if not isinstance(map_, (PMap, CMap)):
args.append(map_.values)
return tuple(args)

Expand Down
20 changes: 20 additions & 0 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,26 @@ def cache_key(self):
return type(self), self.base_map.cache_key, tuple(self.permutation)


@dataclass(eq=False, init=False)
class ComposedMapKernelArg:
"""Class representing a composed map input to the kernel.
:param base_maps: An arbitrary combination of :class:`MapKernelArg`s, :class:`PermutedMapKernelArg`s, and :class:`ComposedMapKernelArg`s.
"""

def __init__(self, *base_maps):
self.base_maps = base_maps

def __post_init__(self):
for m in self.base_maps:
if not isinstance(m, (MapKernelArg, PermutedMapKernelArg, ComposedMapKernelArg)):
raise TypeError("base_maps must be a combination of MapKernelArgs, PermutedMapKernelArgs, and ComposedMapKernelArgs")

@property
def cache_key(self):
return type(self), tuple(m.cache_key for m in self.base_maps)


@dataclass(frozen=True)
class GlobalKernelArg:
"""Class representing a :class:`pyop2.types.Global` being passed to the kernel.
Expand Down
4 changes: 2 additions & 2 deletions pyop2/op2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from pyop2.types import (
Set, ExtrudedSet, MixedSet, Subset, DataSet, MixedDataSet,
Map, MixedMap, PermutedMap, Sparsity, Halo,
Map, MixedMap, PermutedMap, ComposedMap, Sparsity, Halo,
Global, GlobalDataSet,
Dat, MixedDat, DatView, Mat
)
Expand All @@ -64,7 +64,7 @@
'MixedSet', 'Subset', 'DataSet', 'GlobalDataSet', 'MixedDataSet',
'Halo', 'Dat', 'MixedDat', 'Mat', 'Global', 'Map', 'MixedMap',
'Sparsity', 'parloop', 'Parloop', 'ParLoop', 'par_loop',
'DatView', 'PermutedMap']
'DatView', 'PermutedMap', 'ComposedMap']


_initialised = False
Expand Down
7 changes: 5 additions & 2 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
MatKernelArg, MixedMatKernelArg, GlobalKernel)
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, CoffeeLocalKernel, LoopyLocalKernel
from pyop2.types import (Access, Global, Dat, DatView, MixedDat, Mat, Set,
MixedSet, ExtrudedSet, Subset, Map, MixedMap)
MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
from pyop2.utils import cached_property


Expand All @@ -25,7 +25,10 @@ class ParloopArg(abc.ABC):
@staticmethod
def check_map(m):
if configuration["type_check"]:
if m.iterset.total_size > 0 and len(m.values_with_halo) == 0:
if isinstance(m, ComposedMap):
for m_ in m.maps_:
ParloopArg.check_map(m_)
elif m.iterset.total_size > 0 and len(m.values_with_halo) == 0:
raise MapValueError(f"{m} is not initialized")


Expand Down
78 changes: 63 additions & 15 deletions pyop2/sparsity.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ cdef extern from "petsc.h":
PETSC_INSERT_VALUES "INSERT_VALUES"
int PetscCalloc1(size_t, void*)
int PetscMalloc1(size_t, void*)
int PetscMalloc2(size_t, void*, size_t, void*)
int PetscFree(void*)
int PetscFree2(void*,void*)
int MatSetValuesBlockedLocal(PETSc.PetscMat, PetscInt, PetscInt*, PetscInt, PetscInt*,
PetscScalar*, PetscInsertMode)
int MatSetValuesLocal(PETSc.PetscMat, PetscInt, PetscInt*, PetscInt, PetscInt*,
Expand Down Expand Up @@ -193,7 +195,9 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d
PetscScalar zero = 0.0
PetscInt nrow, ncol
PetscInt rarity, carity, tmp_rarity, tmp_carity
PetscInt[:, ::1] rmap, cmap
PetscInt[:, ::1] rmap, cmap, tempmap
PetscInt **rcomposedmaps = NULL, **ccomposedmaps = NULL
PetscInt nrcomposedmaps = 0, nccomposedmaps = 0, rset_entry, cset_entry
PetscInt *rvals
PetscInt *cvals
PetscInt *roffset
Expand All @@ -213,23 +217,52 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d
set_size = pair[0].iterset.size
if set_size == 0:
continue
# Memoryviews require writeable buffers
rflag = set_writeable(pair[0])
cflag = set_writeable(pair[1])
# Map values
rmap = pair[0].values_with_halo
cmap = pair[1].values_with_halo
rflags = []
cflags = []
if isinstance(pair[0], op2.ComposedMap):
m = pair[0].flattened_maps[0]
rflags.append(set_writeable(m))
rmap = m.values_with_halo
nrcomposedmaps = len(pair[0].flattened_maps) - 1
else:
rflags.append(set_writeable(pair[0])) # Memoryviews require writeable buffers
rmap = pair[0].values_with_halo # Map values
if isinstance(pair[1], op2.ComposedMap):
m = pair[1].flattened_maps[0]
cflags.append(set_writeable(m))
cmap = m.values_with_halo
nccomposedmaps = len(pair[1].flattened_maps) - 1
else:
cflags.append(set_writeable(pair[1]))
cmap = pair[1].values_with_halo
# Handle ComposedMaps
CHKERR(PetscMalloc2(nrcomposedmaps, &rcomposedmaps, nccomposedmaps, &ccomposedmaps))
for i in range(nrcomposedmaps):
m = pair[0].flattened_maps[1 + i]
rflags.append(set_writeable(m))
tempmap = m.values_with_halo
rcomposedmaps[i] = &tempmap[0, 0]
for i in range(nccomposedmaps):
m = pair[1].flattened_maps[1 + i]
cflags.append(set_writeable(m))
tempmap = m.values_with_halo
ccomposedmaps[i] = &tempmap[0, 0]
# Arity of maps
rarity = pair[0].arity
carity = pair[1].arity

if not extruded:
# The non-extruded case is easy, we just walk over the
# rmap and cmap entries and set a block of values.
CHKERR(PetscCalloc1(rarity*carity*rdim*cdim, &values))
for set_entry in range(set_size):
CHKERR(MatSetValuesBlockedLocal(mat.mat, rarity, &rmap[set_entry, 0],
carity, &cmap[set_entry, 0],
rset_entry = <PetscInt>set_entry
cset_entry = <PetscInt>set_entry
for i in range(nrcomposedmaps):
rset_entry = rcomposedmaps[nrcomposedmaps - 1 - i][rset_entry]
for i in range(nccomposedmaps):
cset_entry = ccomposedmaps[nccomposedmaps - 1 - i][cset_entry]
CHKERR(MatSetValuesBlockedLocal(mat.mat, rarity, &rmap[<int>rset_entry, 0],
carity, &cmap[<int>cset_entry, 0],
values, PETSC_INSERT_VALUES))
else:
# The extruded case needs a little more work.
Expand Down Expand Up @@ -268,6 +301,12 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d
for i in range(carity):
coffset[i] = pair[1].offset[i]
for set_entry in range(set_size):
rset_entry = <PetscInt>set_entry
cset_entry = <PetscInt>set_entry
for i in range(nrcomposedmaps):
rset_entry = rcomposedmaps[nrcomposedmaps - 1 - i][rset_entry]
for i in range(nccomposedmaps):
cset_entry = ccomposedmaps[nccomposedmaps - 1 - i][cset_entry]
if constant_layers:
layer_start = layers[0, 0]
layer_end = layers[0, 1] - 1
Expand All @@ -287,15 +326,15 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d

# In the case of tmp_rarity == rarity this is just:
#
# rvals[i] = rmap[set_entry, i] + layer_start * roffset[i]
# rvals[i] = rmap[rset_entry, i] + layer_start * roffset[i]
#
# But this means less special casing.
for i in range(tmp_rarity):
rvals[i] = rmap[set_entry, i % rarity] + \
rvals[i] = rmap[<int>rset_entry, i % rarity] + \
(layer_start - layer_bottom + i // rarity) * roffset[i % rarity]
# Ditto
for i in range(tmp_carity):
cvals[i] = cmap[set_entry, i % carity] + \
cvals[i] = cmap[<int>cset_entry, i % carity] + \
(layer_start - layer_bottom + i // carity) * coffset[i % carity]
for layer in range(layer_start, layer_end):
CHKERR(MatSetValuesBlockedLocal(mat.mat, tmp_rarity, rvals,
Expand All @@ -310,6 +349,15 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d
CHKERR(PetscFree(cvals))
CHKERR(PetscFree(roffset))
CHKERR(PetscFree(coffset))
restore_writeable(pair[0], rflag)
restore_writeable(pair[1], cflag)
CHKERR(PetscFree2(rcomposedmaps, ccomposedmaps))
if isinstance(pair[0], op2.ComposedMap):
for m, rflag in zip(pair[0].flattened_maps, rflags):
restore_writeable(m, rflag)
else:
restore_writeable(pair[0], rflags[0])
if isinstance(pair[1], op2.ComposedMap):
for m, cflag in zip(pair[1].flattened_maps, cflags):
restore_writeable(m, cflag)
else:
restore_writeable(pair[1], cflags[0])
CHKERR(PetscFree(values))
94 changes: 94 additions & 0 deletions pyop2/types/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ def __le__(self, o):
"""self<=o if o equals self or self._parent <= o."""
return self == o

@utils.cached_property
def flattened_maps(self):
"""Return all component maps.
This is useful to flatten nested :class:`ComposedMap`s."""
return (self, )


class PermutedMap(Map):
"""Composition of a standard :class:`Map` with a constant permutation.
Expand All @@ -173,6 +180,10 @@ class PermutedMap(Map):
want two global-sized data structures.
"""
def __init__(self, map_, permutation):
if not isinstance(map_, Map):
raise TypeError("map_ must be a Map instance")
if isinstance(map_, ComposedMap):
raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing")
self.map_ = map_
self.permutation = np.asarray(permutation, dtype=Map.dtype)
assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all()
Expand All @@ -192,6 +203,85 @@ def __getattr__(self, name):
return getattr(self.map_, name)


class ComposedMap(Map):
"""Composition of :class:`Map`s, :class:`PermutedMap`s, and/or :class:`ComposedMap`s.
:arg maps_: The maps to compose.
Where normally staging to element data is performed as
.. code-block::
local[i] = global[map[i]]
With a :class:`ComposedMap` we instead get
.. code-block::
local[i] = global[maps_[0][maps_[1][maps_[2][...[i]]]]]
This might be useful if the map you want can be represented by
a composition of existing maps.
"""
def __init__(self, *maps_, name=None):
if not all(isinstance(m, Map) for m in maps_):
raise TypeError("All maps must be Map instances")
for tomap, frommap in zip(maps_[:-1], maps_[1:]):
if tomap.iterset is not frommap.toset:
raise ex.MapTypeError("tomap.iterset must match frommap.toset")
if tomap.comm is not frommap.comm:
raise ex.MapTypeError("All maps needs to share a communicator")
if frommap.arity != 1:
raise ex.MapTypeError("frommap.arity must be 1")
self._iterset = maps_[-1].iterset
self._toset = maps_[0].toset
self.comm = self._toset.comm
self._arity = maps_[0].arity
# Don't call super().__init__() to avoid calling verify_reshape()
self._values = None
self.shape = (self._iterset.total_size, self._arity)
self._name = name or "cmap_#x%x" % id(self)
self._offset = maps_[0]._offset
# A cache for objects built on top of this map
self._cache = {}
self.maps_ = tuple(maps_)

@utils.cached_property
def _kernel_args_(self):
return tuple(itertools.chain(*[m._kernel_args_ for m in self.maps_]))

@utils.cached_property
def _wrapper_cache_key_(self):
return tuple(m._wrapper_cache_key_ for m in self.maps_)

@utils.cached_property
def _global_kernel_arg(self):
from pyop2.global_kernel import ComposedMapKernelArg

return ComposedMapKernelArg(*(m._global_kernel_arg for m in self.maps_))

@utils.cached_property
def values(self):
raise RuntimeError("ComposedMap does not store values directly")

@utils.cached_property
def values_with_halo(self):
raise RuntimeError("ComposedMap does not store values directly")

def __str__(self):
return "OP2 ComposedMap of Maps: [%s]" % ",".join([str(m) for m in self.maps_])

def __repr__(self):
return "ComposedMap(%s)" % ",".join([repr(m) for m in self.maps_])

def __le__(self, o):
raise NotImplementedError("__le__ not implemented for ComposedMap")

@utils.cached_property
def flattened_maps(self):
return tuple(itertools.chain(*(m.flattened_maps for m in self.maps_)))


class MixedMap(Map, caching.ObjectCached):
r"""A container for a bag of :class:`Map`\s."""

Expand Down Expand Up @@ -315,3 +405,7 @@ def __str__(self):

def __repr__(self):
return "MixedMap(%r)" % (self._maps,)

@utils.cached_property
def flattened_maps(self):
raise NotImplementedError("flattend_maps should not be necessary for MixedMap")
4 changes: 2 additions & 2 deletions pyop2/types/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pyop2.types.access import Access
from pyop2.types.data_carrier import DataCarrier
from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
from pyop2.types.map import Map
from pyop2.types.map import Map, ComposedMap
from pyop2.types.set import MixedSet, Set, Subset


Expand Down Expand Up @@ -165,7 +165,7 @@ def _process_args(cls, dsets, maps, *, iteration_regions=None, name=None, nest=N
if not isinstance(m, Map):
raise ex.MapTypeError(
"All maps must be of type map, not type %r" % type(m))
if len(m.values_with_halo) == 0 and m.iterset.total_size > 0:
if not isinstance(m, ComposedMap) and len(m.values_with_halo) == 0 and m.iterset.total_size > 0:
raise ex.MapValueError(
"Unpopulated map values when trying to build sparsity.")
# Make sure that the "to" Set of each map in a pair is the set of
Expand Down
Loading

0 comments on commit 1ba576e

Please sign in to comment.