From e0a4d3a9e143ae063015efe52dd454b91929731d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 24 Jan 2024 16:47:11 +0000 Subject: [PATCH] Passthrough params (#708) Pass objects to local kernels without packing and unpacking. --------- Co-authored-by: Connor Ward --- pyop2/codegen/builder.py | 40 ++++++++++----- pyop2/codegen/representation.py | 2 - pyop2/datatypes.py | 8 +++ pyop2/global_kernel.py | 10 ++++ pyop2/op2.py | 3 +- pyop2/parloop.py | 88 +++++++++++++++++++++++++++++++-- test/unit/test_direct_loop.py | 36 ++++++++++++++ 7 files changed, 170 insertions(+), 17 deletions(-) diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index 583e50f10..89cf31fcf 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -4,9 +4,8 @@ from functools import reduce import numpy -from loopy.types import OpaqueType from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, - MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg) + MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg, PassthroughKernelArg) from pyop2.codegen.representation import (Accumulate, Argument, Comparison, Conditional, DummyInstruction, Extent, FixedIndex, FunctionCall, Index, Indexed, @@ -16,16 +15,13 @@ PreUnpackInst, Product, RuntimeIndex, Sum, Symbol, UnpackInst, Variable, When, Zero) -from pyop2.datatypes import IntType +from pyop2.datatypes import IntType, OpaqueType from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS, ON_TOP, READ, RW, WRITE) from pyop2.utils import cached_property -class PetscMat(OpaqueType): - - def __init__(self): - super().__init__(name="Mat") +MatType = OpaqueType("Mat") def _Remainder(a, b): @@ -226,6 +222,23 @@ def emit_unpack_instruction(self, *, loop_indices=None): """Either yield an instruction, or else return an empty tuple (to indicate no instruction)""" +class PassthroughPack(Pack): + def __init__(self, outer): + self.outer = outer + + def kernel_arg(self, loop_indices=None): + return self.outer + + def pack(self, loop_indices=None): + pass + + def emit_pack_instruction(self, **kwargs): + return () + + def emit_unpack_instruction(self, **kwargs): + return () + + class GlobalPack(Pack): def __init__(self, outer, access, init_with_zero=False): @@ -813,7 +826,12 @@ def add_argument(self, arg): dtype = local_arg.dtype interior_horizontal = self.iteration_region == ON_INTERIOR_FACETS - if isinstance(arg, GlobalKernelArg): + if isinstance(arg, PassthroughKernelArg): + argument = Argument((), dtype, pfx="arg") + pack = PassthroughPack(argument) + self.arguments.append(argument) + + elif isinstance(arg, GlobalKernelArg): argument = Argument(arg.dim, dtype, pfx="glob") pack = GlobalPack(argument, access, @@ -856,7 +874,7 @@ def add_argument(self, arg): pack = MixedDatPack(packs, access, dtype, interior_horizontal=interior_horizontal) elif isinstance(arg, MatKernelArg): - argument = Argument((), PetscMat(), pfx="mat") + argument = Argument((), MatType, pfx="mat") maps = tuple(self._add_map(m, arg.unroll) for m in arg.maps) pack = arg.pack(argument, access, maps, @@ -866,7 +884,7 @@ def add_argument(self, arg): elif isinstance(arg, MixedMatKernelArg): packs = [] for a in arg: - argument = Argument((), PetscMat(), pfx="mat") + argument = Argument((), MatType, pfx="mat") maps = tuple(self._add_map(m, a.unroll) for m in a.maps) @@ -949,7 +967,7 @@ def kernel_call(self): args = self.kernel_args access = tuple(self.loopy_argument_accesses) # assuming every index is free index - free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args)) + free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args if isinstance(arg, Indexed))) # remove runtime index free_indices = tuple(i for i in free_indices if isinstance(i, Index)) if self.pass_layer_to_kernel: diff --git a/pyop2/codegen/representation.py b/pyop2/codegen/representation.py index 89ed46d96..285525078 100644 --- a/pyop2/codegen/representation.py +++ b/pyop2/codegen/representation.py @@ -352,8 +352,6 @@ def __new__(cls, aggregate, multiindex): for index, extent in zip(multiindex, aggregate.shape): if isinstance(index, Index): index.set_extent(extent) - if not multiindex: - return aggregate self = super().__new__(cls) self.children = (aggregate, multiindex) diff --git a/pyop2/datatypes.py b/pyop2/datatypes.py index 41ff3b597..6dccfdd4d 100644 --- a/pyop2/datatypes.py +++ b/pyop2/datatypes.py @@ -69,3 +69,11 @@ def dtype_limits(dtype): except ValueError as e: raise ValueError("Unable to determine numeric limits from %s" % dtype) from e return info.min, info.max + + +class OpaqueType(lp.types.OpaqueType): + def __init__(self, name): + super().__init__(name=name) + + def __repr__(self): + return self.name diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 91911a253..536d717e9 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -206,6 +206,16 @@ def pack(self): return DatPack +class PassthroughKernelArg: + @property + def cache_key(self): + return type(self) + + @property + def maps(self): + return () + + @dataclass(frozen=True) class MixedMatKernelArg: """Class representing a :class:`pyop2.types.MixedDat` being passed to the kernel. diff --git a/pyop2/op2.py b/pyop2/op2.py index 434fc24ac..85788eafa 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -36,6 +36,7 @@ import atexit from pyop2.configuration import configuration +from pyop2.datatypes import OpaqueType # noqa: F401 from pyop2.logger import debug, info, warning, error, critical, set_log_level from pyop2.mpi import MPI, COMM_WORLD, collective @@ -52,7 +53,7 @@ from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401 MatKernelArg, MixedMatKernelArg, MapKernelArg, GlobalKernel) from pyop2.parloop import (GlobalParloopArg, DatParloopArg, MixedDatParloopArg, # noqa: F401 - MatParloopArg, MixedMatParloopArg, Parloop, parloop, par_loop) + MatParloopArg, MixedMatParloopArg, PassthroughArg, Parloop, parloop, par_loop) from pyop2.parloop import (GlobalLegacyArg, DatLegacyArg, MixedDatLegacyArg, # noqa: F401 MatLegacyArg, MixedMatLegacyArg, LegacyParloop, ParLoop) diff --git a/pyop2/parloop.py b/pyop2/parloop.py index cf96ba5b4..384576fa8 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -13,7 +13,7 @@ from pyop2.datatypes import as_numpy_dtype from pyop2.exceptions import KernelTypeError, MapValueError, SetTypeError from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, - MatKernelArg, MixedMatKernelArg, GlobalKernel) + MatKernelArg, MixedMatKernelArg, PassthroughKernelArg, GlobalKernel) from pyop2.local_kernel import LocalKernel, CStringLocalKernel, LoopyLocalKernel from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set, MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap) @@ -39,6 +39,10 @@ class GlobalParloopArg(ParloopArg): data: Global + @property + def _kernel_args_(self): + return self.data._kernel_args_ + @property def map_kernel_args(self): return () @@ -59,6 +63,10 @@ def __post_init__(self): if self.map_ is not None: self.check_map(self.map_) + @property + def _kernel_args_(self): + return self.data._kernel_args_ + @property def map_kernel_args(self): return self.map_._kernel_args_ if self.map_ else () @@ -81,6 +89,10 @@ class MixedDatParloopArg(ParloopArg): def __post_init__(self): self.check_map(self.map_) + @property + def _kernel_args_(self): + return self.data._kernel_args_ + @property def map_kernel_args(self): return self.map_._kernel_args_ if self.map_ else () @@ -102,6 +114,10 @@ def __post_init__(self): for m in self.maps: self.check_map(m) + @property + def _kernel_args_(self): + return self.data._kernel_args_ + @property def map_kernel_args(self): rmap, cmap = self.maps @@ -120,12 +136,34 @@ def __post_init__(self): for m in self.maps: self.check_map(m) + @property + def _kernel_args_(self): + return self.data._kernel_args_ + @property def map_kernel_args(self): rmap, cmap = self.maps return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_))) +@dataclass +class PassthroughParloopArg(ParloopArg): + # a pointer + data: int + + @property + def _kernel_args_(self): + return (self.data,) + + @property + def map_kernel_args(self): + return () + + @property + def maps(self): + return () + + class Parloop: """A parallel loop invocation. @@ -167,7 +205,7 @@ def arglist(self): """Prepare the argument list for calling generated code.""" arglist = self.iterset._kernel_args_ for d in self.arguments: - arglist += d.data._kernel_args_ + arglist += d._kernel_args_ # Collect an ordered set of maps (ignore duplicates) maps = {m: None for d in self.arguments for m in d.map_kernel_args} @@ -224,6 +262,8 @@ def __call__(self): def increment_dat_version(self): """Increment dat versions of :class:`DataCarrier`s in the arguments.""" for lk_arg, gk_arg, pl_arg in self.zipped_arguments: + if isinstance(pl_arg, PassthroughParloopArg): + continue assert isinstance(pl_arg.data, DataCarrier) if lk_arg.access is not Access.READ: if pl_arg.data in self.reduced_globals: @@ -520,6 +560,10 @@ class GlobalLegacyArg(LegacyArg): data: Global access: Access + @property + def dtype(self): + return self.data.dtype + @property def global_kernel_arg(self): return GlobalKernelArg(self.data.dim) @@ -537,6 +581,10 @@ class DatLegacyArg(LegacyArg): map_: Optional[Map] access: Access + @property + def dtype(self): + return self.data.dtype + @property def global_kernel_arg(self): map_arg = self.map_._global_kernel_arg if self.map_ is not None else None @@ -556,6 +604,10 @@ class MixedDatLegacyArg(LegacyArg): map_: MixedMap access: Access + @property + def dtype(self): + return self.data.dtype + @property def global_kernel_arg(self): args = [] @@ -579,6 +631,10 @@ class MatLegacyArg(LegacyArg): lgmaps: Optional[Tuple[Any, Any]] = None needs_unrolling: Optional[bool] = False + @property + def dtype(self): + return self.data.dtype + @property def global_kernel_arg(self): map_args = [m._global_kernel_arg for m in self.maps] @@ -599,6 +655,10 @@ class MixedMatLegacyArg(LegacyArg): lgmaps: Tuple[Any] = None needs_unrolling: Optional[bool] = False + @property + def dtype(self): + return self.data.dtype + @property def global_kernel_arg(self): nrows, ncols = self.data.sparsity.shape @@ -618,6 +678,28 @@ def parloop_arg(self): return MixedMatParloopArg(self.data, tuple(self.maps), self.lgmaps) +@dataclass +class PassthroughArg(LegacyArg): + """Argument that is simply passed to the local kernel without packing. + + :param dtype: The datatype of the argument. This is needed for code generation. + :param data: A pointer to the data. + """ + # We don't know what the local kernel is doing with this argument + access = Access.RW + + dtype: Any + data: int + + @property + def global_kernel_arg(self): + return PassthroughKernelArg() + + @property + def parloop_arg(self): + return PassthroughParloopArg(self.data) + + def ParLoop(*args, **kwargs): return LegacyParloop(*args, **kwargs) @@ -641,7 +723,7 @@ def LegacyParloop(local_knl, iterset, *args, **kwargs): # finish building the local kernel local_knl.accesses = tuple(a.access for a in args) if isinstance(local_knl, CStringLocalKernel): - local_knl.dtypes = tuple(a.data.dtype for a in args) + local_knl.dtypes = tuple(a.dtype for a in args) global_knl_args = tuple(a.global_kernel_arg for a in args) extruded = iterset._extruded diff --git a/test/unit/test_direct_loop.py b/test/unit/test_direct_loop.py index 3d00ac561..2524a78f3 100644 --- a/test/unit/test_direct_loop.py +++ b/test/unit/test_direct_loop.py @@ -34,6 +34,7 @@ import pytest import numpy as np +from petsc4py import PETSc from pyop2 import op2 from pyop2.exceptions import MapValueError @@ -249,6 +250,41 @@ def test_kernel_cplusplus(self, delems): assert (y.data == 10.5).all() + def test_passthrough_mat(self): + niters = 10 + iterset = op2.Set(niters) + + c_kernel = """ +static void mat_inc(Mat mat) { + PetscScalar values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + PetscInt idxs[] = {0, 2, 4}; + MatSetValues(mat, 3, idxs, 3, idxs, values, ADD_VALUES); +} + """ + kernel = op2.Kernel(c_kernel, "mat_inc") + + # create a tiny 5x5 sparse matrix + petsc_mat = PETSc.Mat().create() + petsc_mat.setSizes(5) + petsc_mat.setUp() + petsc_mat.setValues([0, 2, 4], [0, 2, 4], np.zeros((3, 3), dtype=PETSc.ScalarType)) + petsc_mat.assemble() + + arg = op2.PassthroughArg(op2.OpaqueType("Mat"), petsc_mat.handle) + op2.par_loop(kernel, iterset, arg) + petsc_mat.assemble() + + assert np.allclose( + petsc_mat.getValues(range(5), range(5)), + [ + [10, 0, 20, 0, 30], + [0]*5, + [40, 0, 50, 0, 60], + [0]*5, + [70, 0, 80, 0, 90], + ] + ) + if __name__ == '__main__': import os