From 9b4f6561109a937d0551234385b5d82bc37e08a0 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Wed, 18 Dec 2024 16:12:09 +0000 Subject: [PATCH 1/4] Untangle compilation process --- firedrake/function.py | 5 ++-- firedrake/mesh.py | 6 ++--- pyop2/compilation.py | 57 ++++++++++-------------------------------- pyop2/global_kernel.py | 24 +++++++++--------- 4 files changed, 31 insertions(+), 61 deletions(-) diff --git a/firedrake/function.py b/firedrake/function.py index da4d264971..f88075c480 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -775,8 +775,8 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None): libspatialindex_so = Path(rtree.core.rt._name).absolute() lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}" ldargs += [str(libspatialindex_so), lsi_runpath] - return compilation.load( - src, "c", c_name, + dll = compilation.load( + src, "c", cppargs=[ f"-I{path.dirname(__file__)}", f"-I{sys.prefix}/include", @@ -785,3 +785,4 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None): ldargs=ldargs, comm=function.comm ) + return getattr(dll, c_name) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index bef4d38bf3..90fa91e1f3 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -2681,8 +2681,8 @@ def _c_locator(self, tolerance=None): libspatialindex_so = Path(rtree.core.rt._name).absolute() lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}" - locator = compilation.load( - src, "c", "locator", + dll = compilation.load( + src, "c", cppargs=[ f"-I{os.path.dirname(__file__)}", f"-I{sys.prefix}/include", @@ -2696,7 +2696,7 @@ def _c_locator(self, tolerance=None): ], comm=self.comm ) - + locator = getattr(dll, "locator") locator.argtypes = [ctypes.POINTER(function._CFunction), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), diff --git a/pyop2/compilation.py b/pyop2/compilation.py index d28c945188..e9123fc891 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -56,7 +56,6 @@ from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError from pyop2.utils import get_petsc_variables -import pyop2.global_kernel from petsc4py import PETSc @@ -424,38 +423,16 @@ def load_hashkey(*args, **kwargs): @mpi.collective @memory_cache(hashkey=load_hashkey) @PETSc.Log.EventDecorator() -def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), - argtypes=None, restype=None, comm=None): +def load(code, extension, cppargs=(), ldargs=(), comm=None): """Build a shared library and return a function pointer from it. - :arg jitmodule: The JIT Module which can generate the code to compile, or - the string representing the source code. + :arg code: The code to compile. :arg extension: extension of the source file (c, cpp) - :arg fn_name: The name of the function to return from the resulting library :arg cppargs: A tuple of arguments to the C compiler (optional) :arg ldargs: A tuple of arguments to the linker (optional) - :arg argtypes: A list of ctypes argument types matching the arguments of - the returned function (optional, pass ``None`` for ``void``). This is - only used when string is passed in instead of JITModule. - :arg restype: The return type of the function (optional, pass - ``None`` for ``void``). :kwarg comm: Optional communicator to compile the code on (only rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD). """ - if isinstance(jitmodule, str): - class StrCode(object): - def __init__(self, code, argtypes): - self.code_to_compile = code - self.cache_key = (None, code) # We peel off the first - # entry, since for a jitmodule, it's a process-local - # cache key - self.argtypes = argtypes - code = StrCode(jitmodule, argtypes) - elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): - code = jitmodule - else: - raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) - global _compiler if _compiler: # Use the global compiler if it has been set @@ -475,15 +452,7 @@ def __init__(self, code, argtypes): # This call is cached on disk so_name = make_so(compiler_instance, code, extension, comm) # This call might be cached in memory by the OS (system dependent) - dll = ctypes.CDLL(so_name) - - if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): - _add_profiling_events(dll, code.local_kernel.events) - - fn = getattr(dll, fn_name) - fn.argtypes = code.argtypes - fn.restype = restype - return fn + return ctypes.CDLL(so_name) def expandWl(ldflags): @@ -519,27 +488,27 @@ def setdefault(self, key, default=None): return self[key] -def _make_so_hashkey(compiler, jitmodule, extension, comm): +def _make_so_hashkey(compiler, code, extension, comm): if extension == "cpp": exe = compiler.cxx compiler_flags = compiler.cxxflags else: exe = compiler.cc compiler_flags = compiler.cflags - return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key) + return (compiler, code, exe, compiler_flags, compiler.ld, compiler.ldflags) -def check_source_hashes(compiler, jitmodule, extension, comm): +def check_source_hashes(compiler, code, extension, comm): """A check to see whether code generated on all ranks is identical. :arg compiler: The compiler to use to create the shared library. - :arg jitmodule: The JIT Module which can generate the code to compile. + :arg code: The code to compile. :arg filename: The filename of the library to create. :arg extension: extension of the source file (c, cpp). :arg comm: Communicator over which to perform compilation. """ # Reconstruct hash from filename - hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm)) + hashval = _as_hexdigest(_make_so_hashkey(compiler, code, extension, comm)) with mpi.temp_internal_comm(comm) as icomm: matching = icomm.allreduce(hashval, op=_check_op) if matching != hashval: @@ -550,7 +519,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm): output.mkdir(parents=True, exist_ok=True) icomm.barrier() with open(srcfile, "w") as fh: - fh.write(jitmodule.code_to_compile) + fh.write(code) icomm.barrier() raise CompilationError(f"Generated code differs across ranks (see output in {output})") @@ -561,11 +530,11 @@ def check_source_hashes(compiler, jitmodule, extension, comm): cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so") ) @PETSc.Log.EventDecorator() -def make_so(compiler, jitmodule, extension, comm, filename=None): +def make_so(compiler, code, extension, comm, filename=None): """Build a shared library and load it :arg compiler: The compiler to use to create the shared library. - :arg jitmodule: The JIT Module which can generate the code to compile. + :arg code: The code to compile. :arg filename: The filename of the library to create. :arg extension: extension of the source file (c, cpp). :arg comm: Communicator over which to perform compilation. @@ -605,7 +574,7 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): with progress(INFO, 'Compiling wrapper'): # Write source code to disk with open(cname, "w") as fh: - fh.write(jitmodule.code_to_compile) + fh.write(code) os.close(descriptor) if not compiler.ld: @@ -650,7 +619,7 @@ def _run(cc, logfile, errfile, step="Compilation", filemode="w"): """)) -def _add_profiling_events(dll, events): +def add_profiling_events(dll, events): """ If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. The event is generated here in python and then set in the shared library, diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index ae13dc1c59..7edfed0771 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -11,7 +11,7 @@ from petsc4py import PETSc from pyop2 import mpi -from pyop2.compilation import load +from pyop2.compilation import add_profiling_events, load from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -366,8 +366,11 @@ def code_to_compile(self): """Return the C/C++ source code as a string.""" from pyop2.codegen.rep2loopy import generate - wrapper = generate(self.builder) - code = lp.generate_code_v2(wrapper) + with PETSc.Log.Event("GlobalKernel: generate loopy"): + wrapper = generate(self.builder) + + with PETSc.Log.Event("GlobalKernel: generate device code"): + code = lp.generate_code_v2(wrapper) if self.local_kernel.cpp: from loopy.codegen.result import process_preambles @@ -397,15 +400,12 @@ def compile(self, comm): + tuple(self.local_kernel.ldargs) ) - return load( - self, - extension, - self.name, - cppargs=cppargs, - ldargs=ldargs, - restype=ctypes.c_int, - comm=comm - ) + dll = load(self.code_to_compile, extension, cppargs=cppargs, ldargs=ldargs, comm=comm) + add_profiling_events(dll, self.local_kernel.events) + fn = getattr(dll, self.name) + fn.argtypes = self.argtypes + fn.restype = ctypes.c_int + return fn @cached_property def argtypes(self): From 8d2c32c8bf036f4c021b08e035d6d46c3e3aa57d Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Wed, 18 Dec 2024 16:48:41 +0000 Subject: [PATCH 2/4] fix pyop2 tests --- tests/pyop2/test_caching.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pyop2/test_caching.py b/tests/pyop2/test_caching.py index 1298991b3e..ba85b12e2b 100644 --- a/tests/pyop2/test_caching.py +++ b/tests/pyop2/test_caching.py @@ -785,7 +785,8 @@ def test_writing_large_so(): if COMM_WORLD.rank == 1: os.remove("big.c") - fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD) + dll = load(program, "c", comm=COMM_WORLD) + fn = getattr(dll, "big") assert fn is not None @@ -800,7 +801,8 @@ def test_two_comms_compile_the_same_code(): } """) - fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD) + dll = load(code, "c", comm=COMM_WORLD) + fn = getattr(dll, "noop") assert fn is not None From b07d1dc6cac71b85bbc34cce0d7c5e90cf05ef16 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Thu, 19 Dec 2024 17:25:23 +0000 Subject: [PATCH 3/4] fixups --- firedrake/preconditioners/fdm.py | 8 ++++++-- firedrake/preconditioners/patch.py | 13 +++++++------ firedrake/supermeshing.py | 9 +++++---- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/firedrake/preconditioners/fdm.py b/firedrake/preconditioners/fdm.py index 1696801cb4..186fbbc28c 100644 --- a/firedrake/preconditioners/fdm.py +++ b/firedrake/preconditioners/fdm.py @@ -1835,13 +1835,17 @@ def setSubMatCSR(comm, triu=False): return cache.setdefault(key, SparseAssembler.load_setSubMatCSR(comm, triu)) @staticmethod - def load_c_code(code, name, **kwargs): + def load_c_code(code, name, comm, argtypes, restype): petsc_dir = get_petsc_dir() cppargs = [f"-I{d}/include" for d in petsc_dir] ldargs = ([f"-L{d}/lib" for d in petsc_dir] + [f"-Wl,-rpath,{d}/lib" for d in petsc_dir] + ["-lpetsc", "-lm"]) - return load(code, "c", name, cppargs=cppargs, ldargs=ldargs, **kwargs) + dll = load(code, "c", cppargs=cppargs, ldargs=ldargs, comm=comm) + fn = getattr(dll, name) + fn.argtypes = argtypes + fn.restype = restype + return fn @staticmethod def load_setSubMatCSR(comm, triu=False): diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index 0a7bad5575..5e9d0d4fa0 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -505,12 +505,13 @@ def load_c_function(code, name, comm): ldargs = (["-L%s/lib" % d for d in get_petsc_dir()] + ["-Wl,-rpath,%s/lib" % d for d in get_petsc_dir()] + ["-lpetsc", "-lm"]) - return load(code, "c", name, - argtypes=[ctypes.c_voidp, ctypes.c_int, ctypes.c_voidp, - ctypes.c_voidp, ctypes.c_voidp, ctypes.c_int, - ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp], - restype=ctypes.c_int, cppargs=cppargs, ldargs=ldargs, - comm=comm) + dll = load(code, "c", cppargs=cppargs, ldargs=ldargs, comm=comm) + fn = getattr(dll, name) + fn.argtypes = [ctypes.c_voidp, ctypes.c_int, ctypes.c_voidp, + ctypes.c_voidp, ctypes.c_voidp, ctypes.c_int, + ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp] + fn.restype = ctypes.c_int + return fn def make_c_arguments(form, kernel, state, get_map, require_state=False, diff --git a/firedrake/supermeshing.py b/firedrake/supermeshing.py index a1ce2cde17..6af10002d0 100644 --- a/firedrake/supermeshing.py +++ b/firedrake/supermeshing.py @@ -432,14 +432,15 @@ def likely(cell_A): includes = ["-I%s/include" % d for d in dirs] libs = ["-L%s/lib" % d for d in dirs] libs = libs + ["-Wl,-rpath,%s/lib" % d for d in dirs] + ["-lpetsc", "-lsupermesh"] - lib = load( - supermesh_kernel_str, "c", "supermesh_kernel", + dll = load( + supermesh_kernel_str, "c", cppargs=includes, ldargs=libs, - argtypes=[ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp], - restype=ctypes.c_int, comm=mesh_A._comm ) + lib = getattr(dll, "supermesh_kernel") + lib.argtypes = [ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp] + lib.restype = ctypes.c_int ammm(V_A, V_B, likely, node_locations_A, node_locations_B, M_SS, ctypes.addressof(lib), mat) if orig_value_size == 1: From 0f37a37ca6069b8988a5abf09210fd2d5f3e501e Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Thu, 19 Dec 2024 17:28:37 +0000 Subject: [PATCH 4/4] linting --- tests/pyop2/test_caching.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pyop2/test_caching.py b/tests/pyop2/test_caching.py index ba85b12e2b..cfd9e6ce7f 100644 --- a/tests/pyop2/test_caching.py +++ b/tests/pyop2/test_caching.py @@ -31,7 +31,6 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. -import ctypes import os import pytest import tempfile