diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index 0dac79eb..d947cfd2 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -5,6 +5,7 @@ import weakref from cuda import nvrtc +from cuda.core.experimental._linker import Linker, LinkerOptions from cuda.core.experimental._module import ObjectCode from cuda.core.experimental._utils import handle_return @@ -27,7 +28,7 @@ class Program: """ class _MembersNeededForFinalize: - __slots__ = ("handle",) + __slots__ = "handle" def __init__(self, program_obj, handle): self.handle = handle @@ -38,26 +39,37 @@ def close(self): handle_return(nvrtc.nvrtcDestroyProgram(self.handle)) self.handle = None - __slots__ = ("__weakref__", "_mnff", "_backend") - _supported_code_type = ("c++",) + __slots__ = ("__weakref__", "_mnff", "_backend", "_linker") + _supported_code_type = ("c++", "ptx") _supported_target_type = ("ptx", "cubin", "ltoir") def __init__(self, code, code_type): self._mnff = Program._MembersNeededForFinalize(self, None) + code_type = code_type.lower() if code_type not in self._supported_code_type: raise NotImplementedError - if code_type.lower() == "c++": + if code_type == "c++": if not isinstance(code, str): raise TypeError # TODO: support pre-loaded headers & include names # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], [])) self._backend = "nvrtc" + + elif code_type == "ptx": + if not isinstance(code, str): + raise TypeError + # TODO: support pre-loaded headers & include names + # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved + self._linker = Linker(ObjectCode(code.encode(), code_type), options=LinkerOptions(arch="sm_89")) + self._backend = "linker" else: raise NotImplementedError + print(self._backend) + def close(self): """Destroy this program.""" self._mnff.close() @@ -122,6 +134,9 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None): return ObjectCode(data, target_type, symbol_mapping=symbol_mapping) + if self._backend == "linker": + return self._linker.link(target_type) + @property def backend(self): """Return the backend type string associated with this program.""" diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index cca01af5..01ab2724 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -38,11 +38,16 @@ def test_program_compile_valid_target_type(): code = 'extern "C" __global__ void my_kernel() {}' program = Program(code, "c++") arch = "".join(str(i) for i in Device().compute_capability) - object_code = program.compile("ptx", options=(f"-arch=compute_{arch}",)) - print(object_code._module.decode()) - kernel = object_code.get_kernel("my_kernel") - assert isinstance(object_code, ObjectCode) - assert isinstance(kernel, Kernel) + ptx_object_code = program.compile("ptx", options=(f"-arch=compute_{arch}",)) + print(ptx_object_code._module.decode()) + program = Program(ptx_object_code._module.decode(), "ptx") + cubin_object_code = program.compile("cubin", options=(f"-arch=compute_{arch}",)) + ptx_kernel = ptx_object_code.get_kernel("my_kernel") + cubin_kernel = cubin_object_code.get_kernel("my_kernel") + assert isinstance(ptx_object_code, ObjectCode) + assert isinstance(cubin_object_code, ObjectCode) + assert isinstance(ptx_kernel, Kernel) + assert isinstance(cubin_kernel, Kernel) def test_program_compile_invalid_target_type():