From 8bc134949740d6bd18e7bd3fb8379aabe160efac Mon Sep 17 00:00:00 2001 From: Alex McCaskey Date: Thu, 17 Oct 2024 10:06:05 -0400 Subject: [PATCH] Enable Python/C++ interop via exposed JIT functionality (#2214) * Enable C++ interop with user Python kernels. Signed-off-by: Alex McCaskey --- cmake/Modules/CMakeLists.txt | 1 + cmake/Modules/CUDAQConfig.cmake | 3 + cmake/Modules/CUDAQPythonInteropConfig.cmake | 13 + include/cudaq/Optimizer/Transforms/Passes.h | 5 +- .../Transforms/PySynthCallableBlockArgs.cpp | 100 ++++++-- python/CMakeLists.txt | 2 + python/cudaq/kernel/analysis.py | 15 +- python/cudaq/kernel/ast_bridge.py | 54 ++++ python/cudaq/kernel/kernel_decorator.py | 45 +++- python/cudaq/kernel/utils.py | 6 + python/extension/CMakeLists.txt | 1 + python/extension/CUDAQuantumExtension.cpp | 40 +++ .../cudaq/platform/py_alt_launch_kernel.cpp | 72 +++++- python/runtime/interop/CMakeLists.txt | 23 ++ python/runtime/interop/PythonCppInterop.cpp | 100 ++++++++ python/runtime/interop/PythonCppInterop.h | 169 +++++++++++++ python/runtime/utils/PyRemoteRESTQPU.cpp | 4 +- python/tests/CMakeLists.txt | 3 + python/tests/interop/CMakeLists.txt | 18 ++ .../tests/interop/quantum_lib/CMakeLists.txt | 13 + .../tests/interop/quantum_lib/quantum_lib.cpp | 32 +++ .../tests/interop/quantum_lib/quantum_lib.h | 22 ++ .../test_cpp_quantum_algorithm_module.cpp | 52 ++++ python/tests/interop/test_interop.py | 235 ++++++++++++++++++ runtime/cudaq.h | 21 +- runtime/cudaq/cudaq.cpp | 17 +- 26 files changed, 1025 insertions(+), 41 deletions(-) create mode 100644 cmake/Modules/CUDAQPythonInteropConfig.cmake create mode 100644 python/runtime/interop/CMakeLists.txt create mode 100644 python/runtime/interop/PythonCppInterop.cpp create mode 100644 python/runtime/interop/PythonCppInterop.h create mode 100644 python/tests/interop/CMakeLists.txt create mode 100644 python/tests/interop/quantum_lib/CMakeLists.txt create mode 100644 python/tests/interop/quantum_lib/quantum_lib.cpp create mode 100644 python/tests/interop/quantum_lib/quantum_lib.h create mode 100644 python/tests/interop/test_cpp_quantum_algorithm_module.cpp create mode 100644 python/tests/interop/test_interop.py diff --git a/cmake/Modules/CMakeLists.txt b/cmake/Modules/CMakeLists.txt index e49504f2ce3..ddd3e0aa76f 100644 --- a/cmake/Modules/CMakeLists.txt +++ b/cmake/Modules/CMakeLists.txt @@ -14,6 +14,7 @@ set(CONFIG_FILES CUDAQConfig.cmake CUDAQEnsmallenConfig.cmake CUDAQPlatformDefaultConfig.cmake + CUDAQPythonInteropConfig.cmake ) set(LANG_FILES CMakeCUDAQCompiler.cmake.in diff --git a/cmake/Modules/CUDAQConfig.cmake b/cmake/Modules/CUDAQConfig.cmake index 507cadeaead..eadda081d04 100644 --- a/cmake/Modules/CUDAQConfig.cmake +++ b/cmake/Modules/CUDAQConfig.cmake @@ -29,6 +29,9 @@ find_dependency(CUDAQNlopt REQUIRED) set (CUDAQEnsmallen_DIR "${CUDAQ_CMAKE_DIR}") find_dependency(CUDAQEnsmallen REQUIRED) +set (CUDAQPythonInterop_DIR "${CUDAQ_CMAKE_DIR}") +find_dependency(CUDAQPythonInterop REQUIRED) + get_filename_component(PARENT_DIRECTORY ${CUDAQ_CMAKE_DIR} DIRECTORY) get_filename_component(CUDAQ_LIBRARY_DIR ${PARENT_DIRECTORY} DIRECTORY) get_filename_component(CUDAQ_INSTALL_DIR ${CUDAQ_LIBRARY_DIR} DIRECTORY) diff --git a/cmake/Modules/CUDAQPythonInteropConfig.cmake b/cmake/Modules/CUDAQPythonInteropConfig.cmake new file mode 100644 index 00000000000..bad0a26b1dc --- /dev/null +++ b/cmake/Modules/CUDAQPythonInteropConfig.cmake @@ -0,0 +1,13 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 
NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +get_filename_component(CUDAQ_PYTHONINTEROP_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) + +if(NOT TARGET cudaq::cudaq-python-interop) + include("${CUDAQ_PYTHONINTEROP_CMAKE_DIR}/CUDAQPythonInteropTargets.cmake") +endif() diff --git a/include/cudaq/Optimizer/Transforms/Passes.h b/include/cudaq/Optimizer/Transforms/Passes.h index 6274a8de291..cf39c803d64 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.h +++ b/include/cudaq/Optimizer/Transforms/Passes.h @@ -46,9 +46,10 @@ std::unique_ptr createRaiseToAffinePass(); std::unique_ptr createUnwindLoweringPass(); std::unique_ptr -createPySynthCallableBlockArgs(const std::vector &); +createPySynthCallableBlockArgs(const llvm::SmallVector &, + bool removeBlockArg = false); inline std::unique_ptr createPySynthCallableBlockArgs() { - return createPySynthCallableBlockArgs({}); + return createPySynthCallableBlockArgs({}, false); } /// Helper function to build an argument synthesis pass. The names of the diff --git a/lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp b/lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp index 6778f8658d2..65b4c48c7e1 100644 --- a/lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp +++ b/lib/Optimizer/Transforms/PySynthCallableBlockArgs.cpp @@ -7,6 +7,7 @@ ******************************************************************************/ #include "PassDetails.h" +#include "cudaq/Optimizer/Builder/Runtime.h" #include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" #include "cudaq/Optimizer/Transforms/Passes.h" @@ -22,12 +23,13 @@ namespace { class ReplaceCallIndirect : public OpConversionPattern { public: - const std::vector &names; - const std::map &blockArgToNameMap; + const SmallVector &names; + // const llvm::DenseMap& blockArgToNameMap; + llvm::DenseMap &blockArgToNameMap; ReplaceCallIndirect(MLIRContext *ctx, - const std::vector &functionNames, - const std::map &map) + const SmallVector &functionNames, + llvm::DenseMap &map) : OpConversionPattern(ctx), names(functionNames), blockArgToNameMap(map) {} @@ -41,13 +43,11 @@ class ReplaceCallIndirect : public OpConversionPattern { if (auto blockArg = dyn_cast(ccCallableFunc.getOperand())) { auto argIdx = blockArg.getArgNumber(); - auto replacementName = names[blockArgToNameMap.at(argIdx)]; + auto replacementName = names[blockArgToNameMap[argIdx]]; auto replacement = module.lookupSymbol( - "__nvqpp__mlirgen__" + replacementName); - if (!replacement) { - op.emitError("Invalid replacement function " + replacementName); + cudaq::runtime::cudaqGenPrefixName + replacementName.str()); + if (!replacement) return failure(); - } rewriter.replaceOpWithNewOp(op, replacement, adaptor.getCalleeOperands()); @@ -59,13 +59,46 @@ class ReplaceCallIndirect : public OpConversionPattern { } }; +class ReplaceCallCallable + : public OpConversionPattern { +public: + const SmallVector &names; + llvm::DenseMap &blockArgToNameMap; + + ReplaceCallCallable(MLIRContext *ctx, + const SmallVector &functionNames, + llvm::DenseMap &map) + : OpConversionPattern(ctx), + names(functionNames), blockArgToNameMap(map) {} + + LogicalResult + matchAndRewrite(cudaq::cc::CallCallableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + 
auto callableOperand = adaptor.getCallee(); + auto module = op->getParentOp()->getParentOfType(); + if (auto blockArg = dyn_cast(callableOperand)) { + auto argIdx = blockArg.getArgNumber(); + auto replacementName = names[blockArgToNameMap[argIdx]]; + auto replacement = module.lookupSymbol( + cudaq::runtime::cudaqGenPrefixName + replacementName.str()); + if (!replacement) + return failure(); + + rewriter.replaceOpWithNewOp(op, replacement, + adaptor.getArgs()); + return success(); + } + return failure(); + } +}; + class UpdateQuakeApplyOp : public OpConversionPattern { public: - const std::vector &names; - const std::map &blockArgToNameMap; + const SmallVector &names; + llvm::DenseMap &blockArgToNameMap; UpdateQuakeApplyOp(MLIRContext *ctx, - const std::vector &functionNames, - const std::map &map) + const SmallVector &functionNames, + llvm::DenseMap &map) : OpConversionPattern(ctx), names(functionNames), blockArgToNameMap(map) {} @@ -77,13 +110,11 @@ class UpdateQuakeApplyOp : public OpConversionPattern { auto ctx = op.getContext(); if (auto blockArg = dyn_cast(callableOperand)) { auto argIdx = blockArg.getArgNumber(); - auto replacementName = names[blockArgToNameMap.at(argIdx)]; + auto replacementName = names[blockArgToNameMap[argIdx]]; auto replacement = module.lookupSymbol( - "__nvqpp__mlirgen__" + replacementName); - if (!replacement) { - op.emitError("Invalid replacement function " + replacementName); + cudaq::runtime::cudaqGenPrefixName + replacementName.str()); + if (!replacement) return failure(); - } rewriter.replaceOpWithNewOp( op, TypeRange{}, FlatSymbolRefAttr::get(ctx, replacement.getName()), @@ -97,10 +128,13 @@ class UpdateQuakeApplyOp : public OpConversionPattern { class PySynthCallableBlockArgs : public cudaq::opt::PySynthCallableBlockArgsBase< PySynthCallableBlockArgs> { +private: + bool removeBlockArg = false; + public: - std::vector names; - PySynthCallableBlockArgs(const std::vector &_names) - : names(_names) {} + SmallVector names; + PySynthCallableBlockArgs(const SmallVector &_names, bool remove) + : removeBlockArg(remove), names(_names) {} void runOnOperation() override { auto op = getOperation(); @@ -109,7 +143,7 @@ class PySynthCallableBlockArgs std::size_t numCallableBlockArgs = 0; // need to map blockArgIdx -> counter(0,1,2,...) 
- std::map blockArgToNamesMap; + llvm::DenseMap blockArgToNamesMap; for (std::size_t i = 0, k = 0; auto ty : op.getFunctionType().getInputs()) { if (isa(ty)) { numCallableBlockArgs++; @@ -129,8 +163,9 @@ class PySynthCallableBlockArgs return; } - patterns.insert( - ctx, names, blockArgToNamesMap); + patterns + .insert( + ctx, names, blockArgToNamesMap); ConversionTarget target(*ctx); // We should remove these operations target.addIllegalOp(); @@ -148,11 +183,22 @@ class PySynthCallableBlockArgs "error synthesizing callable functions for python.\n"); signalPassFailure(); } + + if (removeBlockArg) { + auto numArgs = op.getNumArguments(); + BitVector argsToErase(numArgs); + for (std::size_t argIndex = 0; argIndex < numArgs; ++argIndex) + if (isa(op.getArgument(argIndex).getType())) + argsToErase.set(argIndex); + + op.eraseArguments(argsToErase); + } } }; } // namespace -std::unique_ptr cudaq::opt::createPySynthCallableBlockArgs( - const std::vector &names) { - return std::make_unique(names); +std::unique_ptr +cudaq::opt::createPySynthCallableBlockArgs(const SmallVector &names, + bool removeBlockArg) { + return std::make_unique(names, removeBlockArg); } diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 1c72ed519f1..6efc9e82572 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -64,3 +64,5 @@ if(CUDAQ_BUILD_TESTS) message(FATAL_ERROR "CUDA Quantum Python Warning - CUDAQ_BUILD_TESTS=TRUE but can't find numpy or pytest modules required for testing.") endif() endif() + +add_subdirectory(runtime/interop) diff --git a/python/cudaq/kernel/analysis.py b/python/cudaq/kernel/analysis.py index 633c2d730bf..e16a2ccf6dc 100644 --- a/python/cudaq/kernel/analysis.py +++ b/python/cudaq/kernel/analysis.py @@ -10,6 +10,7 @@ from .utils import globalAstRegistry, globalKernelRegistry, mlirTypeFromAnnotation from ..mlir.dialects import cc from ..mlir.ir import * +from ..mlir._mlir_libs._quakeDialects import cudaq_runtime class MidCircuitMeasurementAnalyzer(ast.NodeVisitor): @@ -161,13 +162,23 @@ def visit_Call(self, node): if len(moduleNames): moduleNames.reverse() + if cudaq_runtime.isRegisteredDeviceModule( + '.'.join(moduleNames)): + return + # This will throw if the function / module is invalid - m = importlib.import_module('.'.join(moduleNames)) + try: + m = importlib.import_module('.'.join(moduleNames)) + except: + return + getattr(m, node.func.attr) name = node.func.attr + if name not in globalAstRegistry: raise RuntimeError( - f"{name} is not a valid kernel to call.") + f"{name} is not a valid kernel to call ({'.'.join(moduleNames)})." + ) self.depKernels[name] = globalAstRegistry[name] diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index 6597b38f5d8..6f04773c3f2 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -1315,6 +1315,59 @@ def visit_Call(self, node): # FindDepKernels has found something like this, loaded it, and now we just # want to get the function name and call it. 
+ # First let's check for registered C++ kernels + cppDevModNames = [] + value = node.func.value + if isinstance(value, ast.Name) and value.id != 'cudaq': + cppDevModNames = [node.func.attr, value.id] + else: + while isinstance(value, ast.Attribute): + cppDevModNames.append(value.attr) + value = value.value + if isinstance(value, ast.Name): + cppDevModNames.append(value.id) + break + + devKey = '.'.join(cppDevModNames[::-1]) + + def get_full_module_path(partial_path): + parts = partial_path.split('.') + for module_name, module in sys.modules.items(): + if module_name.endswith(parts[0]): + try: + obj = module + for part in parts[1:]: + obj = getattr(obj, part) + return f"{module_name}.{'.'.join(parts[1:])}" + except AttributeError: + continue + return partial_path + + devKey = get_full_module_path(devKey) + if cudaq_runtime.isRegisteredDeviceModule(devKey): + maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( + self.module, devKey + '.' + node.func.attr) + if maybeKernelName == None: + maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( + self.module, devKey) + if maybeKernelName != None: + otherKernel = SymbolTable( + self.module.operation)[maybeKernelName] + fType = otherKernel.type + if len(fType.inputs) != len(node.args): + funcName = node.func.id if hasattr( + node.func, 'id') else node.func.attr + self.emitFatalError( + f"invalid number of arguments passed to callable {funcName} ({len(node.args)} vs required {len(fType.inputs)})", + node) + + [self.visit(arg) for arg in node.args] + values = [self.popValue() for _ in node.args] + values.reverse() + values = [self.ifPointerThenLoad(v) for v in values] + func.CallOp(otherKernel, values) + return + # Start by seeing if we have mod1.mod2.mod3... moduleNames = [] value = node.func.value @@ -1816,6 +1869,7 @@ def bodyBuilder(iterVal): values = [self.popValue() for _ in node.args] values.reverse() + values = [self.ifPointerThenLoad(v) for v in values] func.CallOp(otherKernel, values) return diff --git a/python/cudaq/kernel/kernel_decorator.py b/python/cudaq/kernel/kernel_decorator.py index 2af09bb891e..e95636439eb 100644 --- a/python/cudaq/kernel/kernel_decorator.py +++ b/python/cudaq/kernel/kernel_decorator.py @@ -12,7 +12,7 @@ from typing import Callable from ..mlir.ir import * from ..mlir.passmanager import * -from ..mlir.dialects import quake, cc +from ..mlir.dialects import quake, cc, func from .ast_bridge import compile_to_mlir, PyASTBridge from .utils import mlirTypeFromPyType, nvqppPrefix, mlirTypeToPyType, globalAstRegistry, emitFatalError, emitErrorIfInvalidPauli, globalRegisteredTypes from .analysis import MidCircuitMeasurementAnalyzer, HasReturnNodeVisitor @@ -220,6 +220,49 @@ def compile(self): self.dependentCaptures = extraMetadata[ 'dependent_captures'] if 'dependent_captures' in extraMetadata else None + def merge_kernel(self, otherMod): + """ + Merge the kernel in this PyKernelDecorator (the ModuleOp) with + the provided ModuleOp. 
+ """ + self.compile() + if not isinstance(otherMod, str): + otherMod = str(otherMod) + newMod = cudaq_runtime.mergeExternalMLIR(self.module, otherMod) + # Get the name of the kernel entry point + name = self.name + for op in newMod.body: + if isinstance(op, func.FuncOp): + for attr in op.attributes: + if 'cudaq-entrypoint' == attr.name: + name = op.name.value.replace(nvqppPrefix, '') + break + + return PyKernelDecorator(None, kernelName=name, module=newMod) + + def synthesize_callable_arguments(self, funcNames): + """ + Given this Kernel has callable block arguments, synthesize away these + callable arguments with the in-module FuncOps with given names. The + name at index 0 in the list corresponds to the first callable block + argument, index 1 to the second callable block argument, etc. + """ + self.compile() + cudaq_runtime.synthPyCallable(self.module, funcNames) + # Reset the argument types by removing the Callable + self.argTypes = [ + a for a in self.argTypes if not cc.CallableType.isinstance(a) + ] + + def extract_c_function_pointer(self, name=None): + """ + Return the C function pointer for the function with given name, or + with the name of this kernel if not provided. + """ + self.compile() + return cudaq_runtime.jitAndGetFunctionPointer( + self.module, nvqppPrefix + self.name if name is None else name) + def __str__(self): """ Return the MLIR Module string representation for this kernel. diff --git a/python/cudaq/kernel/utils.py b/python/cudaq/kernel/utils.py index f3c0f1e52be..fe4bcde429e 100644 --- a/python/cudaq/kernel/utils.py +++ b/python/cudaq/kernel/utils.py @@ -397,6 +397,12 @@ def mlirTypeToPyType(argType): if F32Type.isinstance(argType): return np.float32 + if quake.VeqType.isinstance(argType): + return qvector + + if cc.CallableType.isinstance(argType): + return Callable + if ComplexType.isinstance(argType): if F64Type.isinstance(ComplexType(argType).element_type): return complex diff --git a/python/extension/CMakeLists.txt b/python/extension/CMakeLists.txt index de6ba323bdf..8fed51ecaf5 100644 --- a/python/extension/CMakeLists.txt +++ b/python/extension/CMakeLists.txt @@ -86,6 +86,7 @@ declare_mlir_python_extension(CUDAQuantumPythonSources.Extension OptTransforms MLIRPass CUDAQTargetConfigUtil + cudaq-python-interop ) target_include_directories(CUDAQuantumPythonSources.Extension INTERFACE diff --git a/python/extension/CUDAQuantumExtension.cpp b/python/extension/CUDAQuantumExtension.cpp index d8cbfb81aa7..daf29bfe60e 100644 --- a/python/extension/CUDAQuantumExtension.cpp +++ b/python/extension/CUDAQuantumExtension.cpp @@ -34,8 +34,11 @@ #include "utils/OpaqueArguments.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Parser/Parser.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "runtime/interop/PythonCppInterop.h" + #include #include #include @@ -243,4 +246,41 @@ PYBIND11_MODULE(_quakeDialects, m) { cudaqRuntime.def("isTerminator", [](MlirOperation op) { return unwrap(op)->hasTrait(); }); + + cudaqRuntime.def( + "isRegisteredDeviceModule", + [](const std::string &name) { + return cudaq::python::isRegisteredDeviceModule(name); + }, + "Return true if the input name (mod1.mod2...) is a registered C++ device " + "module."); + + cudaqRuntime.def( + "checkRegisteredCppDeviceKernel", + [](MlirModule mod, + const std::string &moduleName) -> std::optional { + std::tuple ret; + try { + ret = cudaq::python::getDeviceKernel(moduleName); + } catch (...) 
{ + return std::nullopt; + } + + // Take the code for the kernel we found + // and add it to the input module, return + // the func op. + auto [kName, code] = ret; + auto ctx = unwrap(mod).getContext(); + auto moduleB = mlir::parseSourceString(code, ctx); + auto moduleA = unwrap(mod); + moduleB->walk([&moduleA](func::FuncOp op) { + if (!moduleA.lookupSymbol(op.getName())) + moduleA.push_back(op.clone()); + return WalkResult::advance(); + }); + return kName; + }, + "Given a python module name like `mod1.mod2.func`, see if there is a " + "registered C++ quantum kernel. If so, add the kernel to the Module and " + "return its name."); } diff --git a/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp b/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp index 5b6c7b59385..b91627de9fc 100644 --- a/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp +++ b/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp @@ -30,6 +30,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/InitAllPasses.h" +#include "mlir/Parser/Parser.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include @@ -95,8 +96,8 @@ jitAndCreateArgs(const std::string &name, MlirModule module, auto cloned = mod.clone(); auto context = cloned.getContext(); PassManager pm(context); - pm.addNestedPass( - cudaq::opt::createPySynthCallableBlockArgs(names)); + pm.addNestedPass(cudaq::opt::createPySynthCallableBlockArgs( + SmallVector(names.begin(), names.end()))); pm.addPass(cudaq::opt::createGenerateDeviceCodeLoader({.jitTime = true})); pm.addPass(cudaq::opt::createGenerateKernelExecution( {.startingArgIdx = startingArgIdx})); @@ -747,5 +748,72 @@ void bindAltLaunchKernel(py::module &mod) { } }, "Remove our pointers to the cudaq states."); + + mod.def( + "mergeExternalMLIR", + [](MlirModule modA, const std::string &modBStr) { + auto ctx = unwrap(modA).getContext(); + auto moduleB = parseSourceString(modBStr, ctx); + auto moduleA = unwrap(modA).clone(); + moduleB->walk([&moduleA](func::FuncOp op) { + if (!moduleA.lookupSymbol(op.getName())) + moduleA.push_back(op.clone()); + return WalkResult::advance(); + }); + return wrap(moduleA); + }, + "Merge the two Modules into a single Module."); + + mod.def( + "synthPyCallable", + [](MlirModule modA, const std::vector &funcNames) { + auto m = unwrap(modA); + auto context = m.getContext(); + PassManager pm(context); + pm.addNestedPass( + cudaq::opt::createPySynthCallableBlockArgs( + SmallVector(funcNames.begin(), funcNames.end()), + true)); + if (failed(pm.run(m))) + throw std::runtime_error( + "cudaq::jit failed to remove callable block arguments."); + + // fix up the mangled name map + DictionaryAttr attr; + m.walk([&](func::FuncOp op) { + if (op->hasAttrOfType("cudaq-entrypoint")) { + auto strAttr = StringAttr::get( + context, op.getName().str() + "_PyKernelEntryPointRewrite"); + attr = DictionaryAttr::get( + context, {NamedAttribute(StringAttr::get(context, op.getName()), + strAttr)}); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (attr) + m->setAttr("quake.mangled_name_map", attr); + }, + "Synthesize away the callable block argument from the entrypoint in modA " + "with the FuncOp of given name."); + + mod.def( + "jitAndGetFunctionPointer", + [](MlirModule mod, const std::string &funcName) { + OpaqueArguments runtimeArgs; + auto noneType = mlir::NoneType::get(unwrap(mod).getContext()); + auto [jit, rawArgs, size, returnOffset] = + 
jitAndCreateArgs(funcName, mod, runtimeArgs, {}, noneType); + + auto funcPtr = jit->lookup(funcName); + if (!funcPtr) { + throw std::runtime_error( + "cudaq::builder failed to get kernelReg function."); + } + + return py::capsule(*funcPtr); + }, + "JIT compile and return the C function pointer for the FuncOp of given " + "name."); } } // namespace cudaq diff --git a/python/runtime/interop/CMakeLists.txt b/python/runtime/interop/CMakeLists.txt new file mode 100644 index 00000000000..75f4e533e1e --- /dev/null +++ b/python/runtime/interop/CMakeLists.txt @@ -0,0 +1,23 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +add_compile_options(-Wno-attributes) +add_library(cudaq-python-interop SHARED PythonCppInterop.cpp) +target_include_directories(cudaq-python-interop PRIVATE + ${PYTHON_INCLUDE_DIRS} + ${pybind11_INCLUDE_DIRS} +) +target_link_libraries(cudaq-python-interop PRIVATE pybind11::module cudaq) +install (FILES PythonCppInterop.h DESTINATION include/cudaq/python/) + +install(TARGETS cudaq-python-interop EXPORT cudaq-python-interop-targets DESTINATION lib) + +install(EXPORT cudaq-python-interop-targets + FILE CUDAQPythonInteropTargets.cmake + NAMESPACE cudaq:: + DESTINATION lib/cmake/cudaq) diff --git a/python/runtime/interop/PythonCppInterop.cpp b/python/runtime/interop/PythonCppInterop.cpp new file mode 100644 index 00000000000..8b8bb0a891f --- /dev/null +++ b/python/runtime/interop/PythonCppInterop.cpp @@ -0,0 +1,100 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. 
* + ******************************************************************************/ +#include "PythonCppInterop.h" +#include "cudaq.h" + +namespace cudaq::python { + +std::string getKernelName(std::string &input) { + size_t pos = 0; + std::string result = ""; + while (true) { + // Find the next occurrence of "func.func @" + size_t start = input.find("func.func @", pos) + 11; + + if (start == std::string::npos) + break; + + // Find the position of the first "(" after "func.func @" + size_t end = input.find("(", start); + + if (end == std::string::npos) + break; + + // Extract the substring + result = input.substr(start, end - start); + + // Check if the substring doesn't contain ".thunk" + if (result.find(".thunk") == std::string::npos) + break; + + // Move the position to continue searching + pos = end; + } + return result; +} + +std::string extractSubstring(const std::string &input, + const std::string &startStr, + const std::string &endStr) { + size_t startPos = input.find(startStr); + if (startPos == std::string::npos) { + return ""; // Start string not found + } + + startPos += startStr.length(); // Move to the end of the start string + size_t endPos = input.find(endStr, startPos); + if (endPos == std::string::npos) { + return ""; // End string not found + } + + return input.substr(startPos, endPos - startPos); +} + +std::tuple +getMLIRCodeAndName(const std::string &name, const std::string mangledArgs) { + auto cppMLIRCode = + cudaq::get_quake(std::remove_cvref_t(name), mangledArgs); + auto kernelName = cudaq::python::getKernelName(cppMLIRCode); + cppMLIRCode = + "module {\nfunc.func @" + kernelName + + extractSubstring(cppMLIRCode, "func.func @" + kernelName, "func.func") + + "\n}"; + return std::make_tuple(kernelName, cppMLIRCode); +} + +/// Map device kernels represented as mod1.mod2...function to their MLIR +/// representation. +static std::unordered_map> + deviceKernelMLIRMap; + +__attribute__((visibility("default"))) void +registerDeviceKernel(const std::string &module, const std::string &name, + const std::string &mangled) { + auto key = module + "." + name; + deviceKernelMLIRMap.insert({key, getMLIRCodeAndName(name, mangled)}); +} + +bool isRegisteredDeviceModule(const std::string &compositeName) { + for (auto &[k, v] : deviceKernelMLIRMap) { + if (k.starts_with(compositeName)) // FIXME is this valid? + return true; + } + + return false; +} + +std::tuple +getDeviceKernel(const std::string &compositeName) { + auto iter = deviceKernelMLIRMap.find(compositeName); + if (iter == deviceKernelMLIRMap.end()) + throw std::runtime_error("Invalid composite name for device kernel map."); + return iter->second; +} + +} // namespace cudaq::python \ No newline at end of file diff --git a/python/runtime/interop/PythonCppInterop.h b/python/runtime/interop/PythonCppInterop.h new file mode 100644 index 00000000000..64c44b2e0ee --- /dev/null +++ b/python/runtime/interop/PythonCppInterop.h @@ -0,0 +1,169 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. 
* + ******************************************************************************/ +#pragma once + +#include + +namespace py = pybind11; + +namespace cudaq::python { + +/// @class CppPyKernelDecorator +/// @brief A C++ wrapper for a Python object representing a CUDA-Q kernel. +class CppPyKernelDecorator { +private: + py::object kernel; + +public: + /// @brief Constructor for CppPyKernelDecorator. + /// @param obj A Python object representing a CUDA-Q kernel. + /// @throw std::runtime_error if the object is not a valid CUDA-Q kernel. + CppPyKernelDecorator(py::object obj) : kernel(obj) { + if (!py::hasattr(obj, "compile")) + throw std::runtime_error("Invalid python kernel object passed, must be " + "annotated with cudaq.kernel"); + } + + /// @brief Compiles the kernel. + void compile() { kernel.attr("compile")(); } + + /// @brief Gets the name of the kernel. + /// @return The name of the kernel as a string. + std::string name() const { return kernel.attr("name").cast(); } + + /// @brief Merges the kernel with another module. + /// @param otherModuleStr The string representation of the other module. + /// @return A new CppPyKernelDecorator object representing the merged kernel. + auto merge_kernel(const std::string &otherModuleStr) { + return CppPyKernelDecorator(kernel.attr("merge_kernel")(otherModuleStr)); + } + + /// @brief Synthesizes callable arguments for the kernel. + /// @param name The name of the kernel. + void synthesize_callable_arguments(const std::vector &names) { + kernel.attr("synthesize_callable_arguments")(names); + } + + /// @brief Extracts a C function pointer from the kernel. + /// @tparam `Args` Variadic template parameter for function arguments. + /// @param kernelName The name of the kernel. + /// @return A function pointer to the extracted C function. + template + auto extract_c_function_pointer(const std::string &kernelName) { + auto capsule = kernel.attr("extract_c_function_pointer")(kernelName) + .cast(); + void *ptr = capsule; + void (*entryPointPtr)(Args &&...) = + reinterpret_cast(ptr); + return *entryPointPtr; + } + + /// @brief Gets the Quake representation of the kernel. + /// @return The Quake representation as a string. + std::string get_quake() { + return kernel.attr("__str__")().cast(); + } +}; + +/// @brief Extracts the kernel name from an input MLIR string. +/// @param input The input string containing the kernel name. +/// @return The extracted kernel name. +std::string getKernelName(std::string &input); + +/// @brief Extracts a sub-string from an input string based on start and end +/// delimiters. +/// @param input The input string to extract from. +/// @param startStr The starting delimiter. +/// @param endStr The ending delimiter. +/// @return The extracted sub-string. +std::string extractSubstring(const std::string &input, + const std::string &startStr, + const std::string &endStr); + +/// @brief Retrieves the MLIR code and mangled kernel name for a given +/// user-level kernel name. +/// @param name The name of the kernel. +/// @return A tuple containing the MLIR code and the kernel name. 
+std::tuple +getMLIRCodeAndName(const std::string &name, const std::string mangled = ""); + +/// @brief Register a C++ device kernel with the given module and name +/// @param module The name of the module containing the kernel +/// @param name The name of the kernel to register +void registerDeviceKernel(const std::string &module, const std::string &name, + const std::string &mangled); + +/// @brief Retrieve the module and name of a registered device kernel +/// @param compositeName The composite name of the kernel (module.name) +/// @return A tuple containing the module name and kernel name +std::tuple +getDeviceKernel(const std::string &compositeName); + +bool isRegisteredDeviceModule(const std::string &compositeName); + +template +constexpr bool is_const_reference_v = + std::is_reference_v && std::is_const_v>; + +template +struct TypeMangler { + static std::string mangle() { + std::string mangledName = typeid(T).name(); + if constexpr (is_const_reference_v) { + mangledName = "RK" + mangledName; + } + return mangledName; + } +}; + +template +std::string getMangledArgsString() { + std::string result; + (result += ... += TypeMangler::mangle()); + + // Remove any namespace cudaq text + std::string search = "N5cudaq"; + std::string replace = ""; + + size_t pos = result.find(search); + while (pos != std::string::npos) { + result.replace(pos, search.length(), replace); + pos = result.find(search, pos + replace.length()); + } + + return result; +} + +/// @brief Add a C++ device kernel that is usable from CUDA-Q Python. +/// @tparam Signature The function signature of the kernel +/// @param m The Python module to add the kernel to +/// @param modName The name of the submodule to add the kernel to +/// @param kernelName The name of the kernel +/// @param docstring The documentation string for the kernel +template +void addDeviceKernelInterop(py::module_ &m, const std::string &modName, + const std::string &kernelName, + const std::string &docstring) { + + auto mangledArgs = getMangledArgsString(); + + // FIXME Maybe Add replacement options (i.e., _pycudaq -> cudaq) + + py::module_ sub; + if (py::hasattr(m, modName.c_str())) + sub = m.attr(modName.c_str()).cast(); + else + sub = m.def_submodule(modName.c_str()); + + sub.def( + kernelName.c_str(), [](Signature...) {}, docstring.c_str()); + cudaq::python::registerDeviceKernel(sub.attr("__name__").cast(), + kernelName, mangledArgs); + return; +} +} // namespace cudaq::python \ No newline at end of file diff --git a/python/runtime/utils/PyRemoteRESTQPU.cpp b/python/runtime/utils/PyRemoteRESTQPU.cpp index 1ec7cd7c09b..ac216eb2ec0 100644 --- a/python/runtime/utils/PyRemoteRESTQPU.cpp +++ b/python/runtime/utils/PyRemoteRESTQPU.cpp @@ -80,8 +80,8 @@ class PyRemoteRESTQPU : public cudaq::BaseRemoteRESTQPU { // specific to python before the rest of the RemoteRESTQPU workflow auto cloned = m_module.clone(); PassManager pm(cloned.getContext()); - pm.addNestedPass( - cudaq::opt::createPySynthCallableBlockArgs(callableNames)); + pm.addNestedPass(cudaq::opt::createPySynthCallableBlockArgs( + SmallVector(callableNames.begin(), callableNames.end()))); cudaq::opt::addAggressiveEarlyInlining(pm); pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass( diff --git a/python/tests/CMakeLists.txt b/python/tests/CMakeLists.txt index bfa7ea4c892..d19d71d1d33 100644 --- a/python/tests/CMakeLists.txt +++ b/python/tests/CMakeLists.txt @@ -8,6 +8,9 @@ # Tests that generate MLIR and are run through both pytest and FileCheck. 
add_subdirectory(mlir) +if (TARGET nvq++) + add_subdirectory(interop) +endif() if (MPI_CXX_FOUND AND CUDA_FOUND) add_subdirectory(parallel) diff --git a/python/tests/interop/CMakeLists.txt b/python/tests/interop/CMakeLists.txt new file mode 100644 index 00000000000..25a185d20c1 --- /dev/null +++ b/python/tests/interop/CMakeLists.txt @@ -0,0 +1,18 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # +set(_origin_prefix "\$ORIGIN") +if(APPLE) + set(_origin_prefix "@loader_path") +endif() +set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_origin_prefix}:${CMAKE_BINARY_DIR}/lib") +add_subdirectory(quantum_lib) +pybind11_add_module(cudaq_test_cpp_algo test_cpp_quantum_algorithm_module.cpp) +target_link_libraries(cudaq_test_cpp_algo PRIVATE cudaq quantum_lib cudaq-python-interop) +target_include_directories(cudaq_test_cpp_algo PRIVATE ${CMAKE_SOURCE_DIR}/python) +add_dependencies(cudaq_test_cpp_algo nvq++) + diff --git a/python/tests/interop/quantum_lib/CMakeLists.txt b/python/tests/interop/quantum_lib/CMakeLists.txt new file mode 100644 index 00000000000..34fb0241880 --- /dev/null +++ b/python/tests/interop/quantum_lib/CMakeLists.txt @@ -0,0 +1,13 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +set(CMAKE_CXX_COMPILER "${CMAKE_BINARY_DIR}/bin/nvq++") +set(CMAKE_CXX_COMPILE_OBJECT " -fPIC --enable-mlir --disable-mlir-links -o -c ") + +# FIXME Error with SHARED, it pulls in all the mlir libraries anyway +add_library(quantum_lib OBJECT quantum_lib.cpp) diff --git a/python/tests/interop/quantum_lib/quantum_lib.cpp b/python/tests/interop/quantum_lib/quantum_lib.cpp new file mode 100644 index 00000000000..c25220e089d --- /dev/null +++ b/python/tests/interop/quantum_lib/quantum_lib.cpp @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. 
* + ******************************************************************************/ + +#include "quantum_lib.h" + +namespace cudaq { +__qpu__ void +entryPoint(const std::function &)> &statePrep) { + cudaq::qvector q(2); + statePrep(q); +} + +__qpu__ void qft(cudaq::qview<> qubits) { + // not really qft, just for testing + h(qubits); +} + +__qpu__ void qft(cudaq::qview<> qubits, const std::vector &x, + std::size_t k) { + h(qubits[k]); + ry(x[0], qubits[k]); +} + +__qpu__ void another(cudaq::qview<> qubits, std::size_t i) { x(qubits[i]); } + +__qpu__ void uccsd(cudaq::qview<> qubits, std::size_t) { h(qubits[0]); } +} // namespace cudaq \ No newline at end of file diff --git a/python/tests/interop/quantum_lib/quantum_lib.h b/python/tests/interop/quantum_lib/quantum_lib.h new file mode 100644 index 00000000000..0de1e17802d --- /dev/null +++ b/python/tests/interop/quantum_lib/quantum_lib.h @@ -0,0 +1,22 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ +#pragma once + +#include "cudaq/qis/qubit_qis.h" + +namespace cudaq { +void entryPoint(const std::function &)> &statePrep); + +void qft(cudaq::qview<> qubits); +void qft(cudaq::qview<> qubits, const std::vector &x, std::size_t k); + +void another(cudaq::qview<> qubits, std::size_t); + +void uccsd(cudaq::qview<> qubits, std::size_t); + +} // namespace cudaq \ No newline at end of file diff --git a/python/tests/interop/test_cpp_quantum_algorithm_module.cpp b/python/tests/interop/test_cpp_quantum_algorithm_module.cpp new file mode 100644 index 00000000000..7adc19f96a1 --- /dev/null +++ b/python/tests/interop/test_cpp_quantum_algorithm_module.cpp @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ +#include +#include + +#include "cudaq.h" +#include "cudaq/algorithms/sample.h" +#include "quantum_lib/quantum_lib.h" +#include "runtime/interop/PythonCppInterop.h" + +namespace py = pybind11; + +PYBIND11_MODULE(cudaq_test_cpp_algo, m) { + + m.def("test_cpp_qalgo", [](py::object statePrepIn) { + // Wrap the kernel and compile, will throw + // if not a valid kernel + cudaq::python::CppPyKernelDecorator statePrep(statePrepIn); + statePrep.compile(); + + // Our library exposes an "entryPoint" kernel, get its + // mangled name and MLIR code + auto [kernelName, cppMLIRCode] = + cudaq::python::getMLIRCodeAndName("entryPoint"); + + // Merge the entryPoint kernel into the input stateprep kernel + auto merged = statePrep.merge_kernel(cppMLIRCode); + + // Synthesize away all callable block arguments + merged.synthesize_callable_arguments({statePrep.name()}); + + // Extract the function pointer. + auto entryPointPtr = merged.extract_c_function_pointer(kernelName); + + // Run... + return cudaq::sample(entryPointPtr); + }); + + // Example of how to expose C++ kernels. 
+ cudaq::python::addDeviceKernelInterop>( + m, "qstd", "qft", "(Fake) Quantum Fourier Transform."); + cudaq::python::addDeviceKernelInterop, std::size_t>( + m, "qstd", "another", "Demonstrate we can have multiple ones."); + + cudaq::python::addDeviceKernelInterop, std::size_t>( + m, "qstd", "uccsd", ""); +} \ No newline at end of file diff --git a/python/tests/interop/test_interop.py b/python/tests/interop/test_interop.py new file mode 100644 index 00000000000..78d46e576e9 --- /dev/null +++ b/python/tests/interop/test_interop.py @@ -0,0 +1,235 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +import cudaq, pytest + +cudaq_test_cpp_algo = pytest.importorskip('cudaq_test_cpp_algo') + + +@pytest.fixture(autouse=True) +def do_something(): + yield + cudaq.__clearKernelRegistries() + + +def test_call_python_from_cpp(): + + @cudaq.kernel + def statePrep(q: cudaq.qvector): + x(q) + + # The test cpp qalgo just takes the user statePrep and + # applies it to a 2 qubit register, so we should expect 11 in the output. + counts = cudaq_test_cpp_algo.test_cpp_qalgo(statePrep) + counts.dump() + assert len(counts) == 1 and '11' in counts + + +def test_mergeExternal(): + + @cudaq.kernel + def kernel(i: int): + q = cudaq.qvector(i) + h(q[0]) + + kernel.compile() + kernel(10) + + otherMod = '''module attributes {quake.mangled_name_map = {__nvqpp__mlirgen__test = "__nvqpp__mlirgen__test_PyKernelEntryPointRewrite"}} { + func.func @__nvqpp__mlirgen__test() attributes {"cudaq-entrypoint"} { + %0 = quake.alloca !quake.veq<2> + %1 = quake.extract_ref %0[0] : (!quake.veq<2>) -> !quake.ref + quake.h %1 : (!quake.ref) -> () + return + } +}''' + newMod = kernel.merge_kernel(otherMod) + print(newMod) + assert '__nvqpp__mlirgen__test' in str( + newMod) and '__nvqpp__mlirgen__kernel' in str(newMod) + + +def test_synthCallable(): + + @cudaq.kernel + def callee(q: cudaq.qview): + x(q[0]) + x(q[1]) + + callee.compile() + + otherMod = '''module attributes {quake.mangled_name_map = {__nvqpp__mlirgen__caller = "__nvqpp__mlirgen__caller_PyKernelEntryPointRewrite"}} { + func.func @__nvqpp__mlirgen__caller(%arg0: !cc.callable<(!quake.veq) -> ()>) attributes {"cudaq-entrypoint"} { + %0 = quake.alloca !quake.veq<2> + %1 = quake.relax_size %0 : (!quake.veq<2>) -> !quake.veq + %2 = cc.callable_func %arg0 : (!cc.callable<(!quake.veq) -> ()>) -> ((!quake.veq) -> ()) + call_indirect %2(%1) : (!quake.veq) -> () + return + } +}''' + + # Merge the external code with the current pure device kernel + newKernel = callee.merge_kernel(otherMod) + print(newKernel.name, newKernel) + # Synthesize away the callable arg with the pure device kernel + newKernel.synthesize_callable_arguments(['callee']) + print(newKernel) + + counts = cudaq.sample(newKernel) + assert len(counts) == 1 and '11' in counts + + +def test_synthCallableCCCallCallableOp(): + + @cudaq.kernel + def callee(q: cudaq.qview): + x(q[0]) + x(q[1]) + + callee.compile() + + otherMod = '''module attributes {quake.mangled_name_map = {__nvqpp__mlirgen__caller = "__nvqpp__mlirgen__caller_PyKernelEntryPointRewrite"}} { + func.func @__nvqpp__mlirgen__adapt_caller(%arg0: i64, %arg1: !cc.callable<(!quake.veq) -> ()>, %arg2: 
!cc.stdvec, %arg3: !cc.stdvec, %arg4: !cc.stdvec) attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = quake.alloca !quake.veq[%arg0 : i64] + cc.call_callable %arg1, %0 : (!cc.callable<(!quake.veq) -> ()>, !quake.veq) -> () + %1 = cc.loop while ((%arg5 = %c0_i64) -> (i64)) { + %2 = cc.stdvec_size %arg2 : (!cc.stdvec) -> i64 + %3 = arith.cmpi ult, %arg5, %2 : i64 + cc.condition %3(%arg5 : i64) + } do { + ^bb0(%arg5: i64): + %2 = cc.loop while ((%arg6 = %c0_i64) -> (i64)) { + %3 = cc.stdvec_size %arg4 : (!cc.stdvec) -> i64 + %4 = arith.cmpi ult, %arg6, %3 : i64 + cc.condition %4(%arg6 : i64) + } do { + ^bb0(%arg6: i64): + %3 = cc.stdvec_data %arg2 : (!cc.stdvec) -> !cc.ptr> + %4 = cc.compute_ptr %3[%arg5] : (!cc.ptr>, i64) -> !cc.ptr + %5 = cc.load %4 : !cc.ptr + %6 = cc.stdvec_data %arg3 : (!cc.stdvec) -> !cc.ptr> + %7 = cc.compute_ptr %6[%arg6] : (!cc.ptr>, i64) -> !cc.ptr + %8 = cc.load %7 : !cc.ptr + %9 = arith.mulf %5, %8 : f64 + %10 = cc.stdvec_data %arg4 : (!cc.stdvec) -> !cc.ptr> + %11 = cc.compute_ptr %10[%arg6] : (!cc.ptr>, i64) -> !cc.ptr + %12 = cc.load %11 : !cc.ptr + quake.exp_pauli %9, %0, %12 : (f64, !quake.veq, !cc.charspan) -> () + cc.continue %arg6 : i64 + } step { + ^bb0(%arg6: i64): + %3 = arith.addi %arg6, %c1_i64 : i64 + cc.continue %3 : i64 + } + cc.continue %arg5 : i64 + } step { + ^bb0(%arg5: i64): + %2 = arith.addi %arg5, %c1_i64 : i64 + cc.continue %2 : i64 + } + return + } +}''' + + # Merge the external code with the current pure device kernel + newKernel = callee.merge_kernel(otherMod) + print(newKernel) + # Synthesize away the callable arg with the pure device kernel + newKernel.synthesize_callable_arguments(['callee']) + print(newKernel) + assert '!cc.callable' not in str(newKernel) + + +def testSynthTwoArgs(): + + from typing import Callable + + @cudaq.kernel + def kernel22(k: Callable[[cudaq.qview], None], j: Callable[[cudaq.qview], + None]): + q = cudaq.qvector(2) + k(q) + j(q) + + @cudaq.kernel + def callee0(q: cudaq.qview): + x(q) + + @cudaq.kernel + def callee1(q: cudaq.qview): + x(q) + + callees = callee0.merge_kernel(callee1) + print(callees) + merged = callees.merge_kernel(kernel22) + print(merged) + + merged.synthesize_callable_arguments(['callee0', 'callee1']) + + print(merged) + counts = cudaq.sample(merged) + counts.dump() + assert '00' in counts and len(counts) == 1 + + +def test_cpp_kernel_from_python_0(): + + from cudaq_test_cpp_algo import qstd + + @cudaq.kernel + def callQftAndAnother(): + q = cudaq.qvector(4) + qstd.qft(q) + h(q) + qstd.another(q, 2) + + callQftAndAnother() + + counts = cudaq.sample(callQftAndAnother) + counts.dump() + assert len(counts) == 1 and '0010' in counts + + +def test_cpp_kernel_from_python_1(): + + @cudaq.kernel + def callQftAndAnother(): + q = cudaq.qvector(4) + cudaq_test_cpp_algo.qstd.qft(q) + h(q) + cudaq_test_cpp_algo.qstd.another(q, 2) + + callQftAndAnother() + + counts = cudaq.sample(callQftAndAnother) + counts.dump() + assert len(counts) == 1 and '0010' in counts + + +def test_cpp_kernel_from_python_2(): + + @cudaq.kernel + def callUCCSD(): + q = cudaq.qvector(4) + cudaq_test_cpp_algo.qstd.uccsd(q, 2) + + callUCCSD() + +def test_capture(): + @cudaq.kernel + def takesCapture(s : int): + pass + + spin = 0 + + @cudaq.kernel(verbose=True) + def entry(): + takesCapture(spin) + entry.compile() \ No newline at end of file diff --git a/runtime/cudaq.h b/runtime/cudaq.h index ffe5ea6fec2..355cd52e6e4 100644 --- a/runtime/cudaq.h +++ 
b/runtime/cudaq.h @@ -32,13 +32,22 @@ extern bool globalFalse; } // namespace __internal__ /// @brief Given a string kernel name, return the corresponding Quake code -// This will throw if the kernel name is unknown to the quake code registry. +/// This will throw if the kernel name is unknown to the quake code registry. std::string get_quake_by_name(const std::string &kernelName); + +/// @brief Given a string kernel name, return the corresponding Quake code. +/// This overload allows one to specify the known mangled arguments string +/// in order to disambiguate overloaded kernel names. +/// This will throw if the kernel name is unknown to the quake code registry. +std::string get_quake_by_name(const std::string &kernelName, + std::optional knownMangledArgs); + /// @brief Given a string kernel name, return the corresponding Quake code. // If `throwException` is set, it will throw if the kernel name is unknown to // the quake code registry. Otherwise, return an empty string in that case. -std::string get_quake_by_name(const std::string &kernelName, - bool throwException); +std::string +get_quake_by_name(const std::string &kernelName, bool throwException, + std::optional knownMangledArgs = std::nullopt); // Simple test to see if the QuantumKernel template // type is a `cudaq::builder` with `operator()(Args...)` @@ -196,6 +205,12 @@ inline std::string get_quake(std::string &&functionName) { return get_quake_by_name(get_kernel_function_name(std::move(functionName))); } +inline std::string get_quake(std::string &&functionName, + const std::string &knownMangledArgs) { + return get_quake_by_name(get_kernel_function_name(std::move(functionName)), + knownMangledArgs); +} + typedef std::size_t (*KernelArgsCreator)(void **, void **); KernelArgsCreator getArgsCreator(const std::string &kernelName); diff --git a/runtime/cudaq/cudaq.cpp b/runtime/cudaq/cudaq.cpp index 9373d051397..10ecc3b914a 100644 --- a/runtime/cudaq/cudaq.cpp +++ b/runtime/cudaq/cudaq.cpp @@ -321,7 +321,8 @@ KernelArgsCreator getArgsCreator(const std::string &kernelName) { } std::string get_quake_by_name(const std::string &kernelName, - bool throwException) { + bool throwException, + std::optional knownMangledArgs) { // A prefix name has a '.' before the C++ mangled name suffix. auto kernelNamePrefix = kernelName + '.'; @@ -342,7 +343,14 @@ std::string get_quake_by_name(const std::string &kernelName, throw std::runtime_error("Quake code for '" + kernelName + "' has multiple matches.\n"); } else { - result = pair.second; + if (!knownMangledArgs.has_value()) + result = pair.second; + else { + if (pair.first.ends_with(knownMangledArgs.value())) { + result = pair.second; + break; + } + } } } } @@ -359,6 +367,11 @@ std::string get_quake_by_name(const std::string &kernelName) { return get_quake_by_name(kernelName, true); } +std::string get_quake_by_name(const std::string &kernelName, + std::optional knownMangledArgs) { + return get_quake_by_name(kernelName, true, knownMangledArgs); +} + bool kernelHasConditionalFeedback(const std::string &kernelName) { auto quakeCode = get_quake_by_name(kernelName, false); return !quakeCode.empty() &&
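
Usage sketch (illustrative only, not part of the patch): the snippet below mirrors the tests added in python/tests/interop/test_interop.py and assumes the cudaq_test_cpp_algo test extension built by this patch is importable.

    import cudaq
    import cudaq_test_cpp_algo

    # C++ calls Python: pass a Python kernel as the state-prep callable of a
    # C++ quantum algorithm exposed through CppPyKernelDecorator.
    @cudaq.kernel
    def statePrep(q: cudaq.qvector):
        x(q)

    counts = cudaq_test_cpp_algo.test_cpp_qalgo(statePrep)
    assert len(counts) == 1 and '11' in counts

    # Python calls C++: invoke device kernels registered with
    # addDeviceKernelInterop (exposed here under the qstd submodule).
    @cudaq.kernel
    def callQftAndAnother():
        q = cudaq.qvector(4)
        cudaq_test_cpp_algo.qstd.qft(q)
        h(q)
        cudaq_test_cpp_algo.qstd.another(q, 2)

    print(cudaq.sample(callQftAndAnother))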