diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp
index e77eada9..50686c43 100644
--- a/ark/api/executor.cpp
+++ b/ark/api/executor.cpp
@@ -154,6 +154,8 @@ class Executor::Impl {
 
     Stream stream() const { return reinterpret_cast<Stream>(stream_raw_); }
 
+    std::shared_ptr<GpuMemory> buffer() const { return buffer_; }
+
     std::string plan() const { return plan_json_.dump_pretty(); }
 
     void compile();
@@ -934,6 +936,8 @@ int Executor::device_id() const { return impl_->device_id(); }
 
 Stream Executor::stream() const { return impl_->stream(); }
 
+std::shared_ptr<GpuMemory> Executor::buffer() const { return impl_->buffer(); }
+
 std::string Executor::plan() const { return impl_->plan(); }
 
 void Executor::compile() { impl_->compile(); }
diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp
index 14ca8761..02a67cd2 100644
--- a/ark/include/ark/executor.hpp
+++ b/ark/include/ark/executor.hpp
@@ -15,6 +15,8 @@ namespace ark {
 
 using Stream = void *;
 
+class GpuMemory;
+
 /// Convenience class for executing a model.
 class Executor {
    public:
@@ -31,6 +33,9 @@ class Executor {
     /// Return the stream of the executor.
     Stream stream() const;
 
+    /// Return the buffer of the executor.
+    std::shared_ptr<GpuMemory> buffer() const;
+
     /// Return the plan string.
     std::string plan() const;
 
diff --git a/examples/tutorial/planner_tutorial.py b/examples/tutorial/planner_tutorial.py
index 1f6c3ac5..6153aaf8 100644
--- a/examples/tutorial/planner_tutorial.py
+++ b/examples/tutorial/planner_tutorial.py
@@ -69,14 +69,13 @@ def perf():
 
     shape = (32, 2048, 2048)
 
-    # input = torch.randn(*shape).to("cuda:0")
-    input = ark.tensor(shape)
+    input = torch.randn(*shape).to("cuda:0")
 
-    output = Softmax()(input)
+    output = Softmax()(ark.Tensor.from_torch(input))
 
-    # if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5):
-    #     print("Correct result")
-    # else:
-    #     print("Incorrect result")
+    if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5):
+        print("Correct result")
+    else:
+        print("Incorrect result")
 
     print(f"Performance: {(perf() * 1e3):.3f} ms/iter")
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index bd25d01e..2e160f8d 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -21,3 +21,16 @@ pybind11_add_module(ark_py ${BIND_SOURCES})
 set_target_properties(ark_py PROPERTIES OUTPUT_NAME _ark_core)
 target_link_libraries(ark_py PRIVATE ark_static)
 target_include_directories(ark_py SYSTEM PRIVATE ${DLPACK_INCLUDE_DIRS})
+target_include_directories(ark_py PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../ark)
+
+if(ARK_USE_CUDA)
+    target_include_directories(ark_py SYSTEM PRIVATE
+        ${CUDAToolkit_INCLUDE_DIRS}
+    )
+endif()
+
+if(ARK_USE_ROCM)
+    target_include_directories(ark_py SYSTEM PRIVATE
+        /opt/rocm/include
+    )
+endif()
diff --git a/python/ark/__init__.py b/python/ark/__init__.py
index c20b50b8..68b03ab2 100644
--- a/python/ark/__init__.py
+++ b/python/ark/__init__.py
@@ -1,12 +1,15 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import sys
 import os
 
 if os.environ.get("ARK_ROOT", None) is None:
     os.environ["ARK_ROOT"] = os.path.abspath(os.path.dirname(__file__))
 
-from . import _ark_core
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import _ark_core
 
 from .model import Model
 
diff --git a/python/ark/data_type.py b/python/ark/data_type.py
index 8ab98210..41c4201c 100644
--- a/python/ark/data_type.py
+++ b/python/ark/data_type.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 import numpy
-from . 
import _ark_core +import _ark_core try: import torch diff --git a/python/ark/init.py b/python/ark/init.py index 32f53079..dbf7c156 100644 --- a/python/ark/init.py +++ b/python/ark/init.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from . import _ark_core +import _ark_core from .model import Model from .runtime import _RuntimeState diff --git a/python/ark/model.py b/python/ark/model.py index 87af88f4..e6208fc1 100644 --- a/python/ark/model.py +++ b/python/ark/model.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from typing import NewType -from ._ark_core import _Model +from _ark_core import _Model _ModelState = NewType("_ModelState", None) diff --git a/python/ark/module.py b/python/ark/module.py index faeeea40..d797da72 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -5,7 +5,7 @@ import numpy as np from typing import Any, Dict, List, Union from .tensor import Tensor, Parameter -from .runtime import Runtime, DefaultPlanner +from .runtime import Runtime, Planner from .ops import tensor from .data_type import DataType @@ -183,7 +183,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: self.built_forward = True with Runtime.get_runtime() as rt: - rt.launch(plan=DefaultPlanner().plan()) + rt.launch(plan=Planner().plan()) for tns, arg in zip(self.forward_input_tensor_args, args): tns.copy(arg) for key, value in self.forward_input_tensor_kwargs.items(): diff --git a/python/ark/ops.py b/python/ark/ops.py index f890e5d1..7d98f51c 100644 --- a/python/ark/ops.py +++ b/python/ark/ops.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import json from typing import Any, Dict, List, Iterable, Union from .tensor import Dims, Tensor, Parameter, NullTensor @@ -13,12 +12,6 @@ def _is_list_or_tuple(obj): return isinstance(obj, list) or isinstance(obj, tuple) -def _config_to_str(config: Union[str, Dict[str, Any]]) -> str: - if isinstance(config, str): - return config - return json.dumps(config) - - def _tensor( shape: Iterable[int], dtype: DataType = fp32, @@ -59,7 +52,6 @@ def add( input: Union[Tensor, float], other: Union[Tensor, float], output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "add", ) -> Union[Tensor, float]: """ @@ -83,14 +75,12 @@ def add( return input + other else: return Tensor( - Model.get_model().copy( - input + other, output._tensor, _config_to_str(config), name - ) + Model.get_model().copy(input + other, output._tensor, name) ) if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().add(a, b, output, _config_to_str(config), name), + Model.get_model().add(a, b, output, name), runtime_id=input.runtime_id, ) @@ -99,16 +89,13 @@ def cast( input: Tensor, dtype: DataType, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "cast", ) -> Tensor: """Type casting.""" if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().cast( - input._tensor, dtype.ctype(), output, _config_to_str(config), name - ), + Model.get_model().cast(input._tensor, dtype.ctype(), output, name), runtime_id=input.runtime_id, ) @@ -130,7 +117,6 @@ def constant( def copy( input: Union[Tensor, float], output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "copy", ) -> Tensor: """Data caopy.""" @@ -139,7 +125,7 @@ def copy( if isinstance(input, Tensor): intput = intput._tensor return Tensor( - Model.get_model().copy(intput, output, _config_to_str(config), 
name), + Model.get_model().copy(intput, output, name), runtime_id=input.runtime_id, ) @@ -148,7 +134,6 @@ def div( input: Tensor, other: Union[Tensor, float], output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "div", ) -> Tensor: """ @@ -164,9 +149,7 @@ def div( raise ValueError("Tensors must be on the same runtime") other = other._tensor return Tensor( - Model.get_model().div( - input._tensor, other, output, _config_to_str(config), name - ), + Model.get_model().div(input._tensor, other, output, name), runtime_id=input.runtime_id, ) @@ -175,7 +158,6 @@ def embedding( input: Tensor, weight: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "embedding", ) -> Tensor: """Embedding layer.""" @@ -185,7 +167,7 @@ def embedding( output = output._tensor return Tensor( Model.get_model().embedding( - input._tensor, weight._tensor, output, _config_to_str(config), name + input._tensor, weight._tensor, output, name ), runtime_id=input.runtime_id, ) @@ -194,7 +176,6 @@ def embedding( def exp( input: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "exp", ) -> Tensor: """ @@ -205,9 +186,7 @@ def exp( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().exp( - input._tensor, output, _config_to_str(config), name - ), + Model.get_model().exp(input._tensor, output, name), runtime_id=input.runtime_id, ) @@ -215,7 +194,6 @@ def exp( def gelu( input: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "gelu", ) -> Tensor: """ @@ -229,9 +207,7 @@ def gelu( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().gelu( - input._tensor, output, _config_to_str(config), name - ), + Model.get_model().gelu(input._tensor, output, name), runtime_id=input.runtime_id, ) @@ -263,7 +239,6 @@ def matmul( output: Tensor = NullTensor, transpose_input: bool = False, transpose_other: bool = False, - config: Union[str, Dict[str, Any]] = "", name: str = "matmul", ) -> Tensor: """ @@ -286,7 +261,6 @@ def matmul( output, transpose_input, transpose_other, - _config_to_str(config), name, ), runtime_id=input.runtime_id, @@ -297,7 +271,6 @@ def mul( input: Tensor, other: Union[Tensor, float], output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "mul", ) -> Tensor: """ @@ -313,9 +286,7 @@ def mul( raise ValueError("Tensors must be on the same runtime") other = other._tensor return Tensor( - Model.get_model().mul( - input._tensor, other, output, _config_to_str(config), name - ), + Model.get_model().mul(input._tensor, other, output, name), runtime_id=input.runtime_id, ) @@ -332,7 +303,6 @@ def reduce_max( axis: int, keepdims: bool = True, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "reduce_max", ) -> Tensor: """ @@ -345,7 +315,7 @@ def reduce_max( output = output._tensor return Tensor( Model.get_model().reduce_max( - input._tensor, axis, keepdims, output, _config_to_str(config), name + input._tensor, axis, keepdims, output, name ), runtime_id=input.runtime_id, ) @@ -356,7 +326,6 @@ def reduce_mean( axis: int, keepdims: bool = True, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "reduce_mean", ) -> Tensor: """ @@ -369,7 +338,7 @@ def reduce_mean( output = output._tensor return Tensor( Model.get_model().reduce_mean( - input._tensor, axis, keepdims, output, _config_to_str(config), name + input._tensor, axis, keepdims, 
output, name ), runtime_id=input.runtime_id, ) @@ -380,7 +349,6 @@ def reduce_sum( axis: int, keepdims: bool = True, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "reduce_sum", ) -> Tensor: """ @@ -395,7 +363,7 @@ def reduce_sum( output = output._tensor return Tensor( Model.get_model().reduce_sum( - input._tensor, axis, keepdims, output, _config_to_str(config), name + input._tensor, axis, keepdims, output, name ), runtime_id=input.runtime_id, ) @@ -404,7 +372,6 @@ def reduce_sum( def relu( input: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "relu", ) -> Tensor: """ @@ -416,9 +383,7 @@ def relu( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().relu( - input._tensor, output, _config_to_str(config), name - ), + Model.get_model().relu(input._tensor, output, name), runtime_id=input.runtime_id, ) @@ -457,7 +422,6 @@ def rope( input: Tensor, other: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "rope", ) -> Tensor: """ @@ -470,9 +434,7 @@ def rope( if input.runtime_id != other.runtime_id: raise ValueError("Tensors must be on the same runtime") return Tensor( - Model.get_model().rope( - input._tensor, other._tensor, output, _config_to_str(config), name - ), + Model.get_model().rope(input._tensor, other._tensor, output, name), runtime_id=input.runtime_id, ) @@ -480,7 +442,6 @@ def rope( def rsqrt( input: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "rsqrt", ) -> Tensor: """ @@ -491,9 +452,7 @@ def rsqrt( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().rsqrt( - input._tensor, output, _config_to_str(config), name - ), + Model.get_model().rsqrt(input._tensor, output, name), runtime_id=input.runtime_id, ) @@ -521,7 +480,6 @@ def sharding( def sigmoid( input: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "sigmoid", ) -> Tensor: """ @@ -533,9 +491,7 @@ def sigmoid( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().sigmoid( - input._tensor, output, _config_to_str(config), name - ), + Model.get_model().sigmoid(input._tensor, output, name), runtime_id=input.runtime_id, ) @@ -543,7 +499,6 @@ def sigmoid( def sqrt( input: Tensor, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "sqrt", ) -> Tensor: """ @@ -554,9 +509,7 @@ def sqrt( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().sqrt( - input._tensor, output, _config_to_str(config), name - ), + Model.get_model().sqrt(input._tensor, output, name), runtime_id=input.runtime_id, ) @@ -565,7 +518,6 @@ def sub( input: Tensor, other: Union[Tensor, float], output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "sub", ) -> Tensor: """ @@ -581,9 +533,7 @@ def sub( raise ValueError("Tensors must be on the same runtime") other = other._tensor return Tensor( - Model.get_model().sub( - input._tensor, other, output, _config_to_str(config), name - ), + Model.get_model().sub(input._tensor, other, output, name), runtime_id=input.runtime_id, ) @@ -613,7 +563,6 @@ def transpose( input: Tensor, perm: Iterable[int], output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "transpose", ) -> Tensor: """ @@ -633,9 +582,7 @@ def transpose( if len(perm) > 4: raise ValueError("Only support perm up to 4 dimensions") return 
Tensor( - Model.get_model().transpose( - input._tensor, perm, output, _config_to_str(config), name - ), + Model.get_model().transpose(input._tensor, perm, output, name), runtime_id=input.runtime_id, ) @@ -648,7 +595,6 @@ def mean( axis: int, keepdims: bool = True, output: Tensor = NullTensor, - config: Union[str, Dict[str, Any]] = "", name: str = "mean", ) -> Tensor: """Alias of reduce_mean.""" @@ -764,9 +710,10 @@ def all_reduce( "reshape", "identity", "sharding", - "reduce_sum", - "reduce_mean", + "noop", "reduce_max", + "reduce_mean", + "reduce_sum", "layernorm", "softmax", "transpose", diff --git a/python/ark/tensor.py b/python/ark/tensor.py index eed7a425..089d3eae 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -4,7 +4,7 @@ import numpy as np from typing import Callable, List, Union, Type -from ._ark_core import _Dims, _Tensor, _NullTensor +from _ark_core import _Dims, _Tensor, _NullTensor from .data_type import DataType from .runtime import Runtime from .model import Model @@ -102,63 +102,6 @@ def to_numpy( rt.executor.tensor_read(self._tensor, ndarray, stream) return ndarray - def to_torch( - self, tensor: torch.Tensor = None, stream: int = 0 - ) -> torch.Tensor: - """ """ - if _no_torch: - raise ImportError("torch is not available") - rt = Runtime.get_runtime(self.runtime_id) - if not rt.launched(): - raise RuntimeError( - "Tensor is not allocated yet. `Tensor.to_torch()` is " - "usable only after you call `Runtime.launch()`." - ) - torch_type = self.dtype().to_torch() - if tensor is None: - dev_name = f"cuda:{rt.executor.device_id()}" - tensor = torch.zeros( - self.shape(), dtype=torch_type, device=torch.device(dev_name) - ) - elif list(tensor.shape) != self.shape(): - raise ValueError( - f"torch tensor shape {list(tensor.shape)} " - f"does not match the tensor {self.shape()}" - ) - elif tensor.dtype != torch_type: - raise ValueError( - f"torch tensor dtype {tensor.dtype} " - f"does not match the tensor {torch_type}" - ) - elif not tensor.is_contiguous(): - raise ValueError("torch tensor is not contiguous in memory") - elif tensor.numel() != self.nelems(): - raise ValueError( - f"torch tensor size {tensor.numel()} " - f"does not match the tensor {self.nelems()}" - ) - tensor_bytes = self.nelems() * self.dtype().element_size() - rt.executor.tensor_read( - self._tensor, tensor.data_ptr(), tensor_bytes, stream, True - ) - return tensor - - def get_torch_view(self) -> torch.Tensor: - """ - Returns a torch tensor that shares the same memory with the device tensor. - """ - if _no_torch: - raise ImportError("torch is not available") - rt = Runtime.get_runtime(self.runtime_id) - if not rt.launched(): - raise RuntimeError( - "Tensor is not allocated yet. `Tensor.get_torch_view()` is " - "usable only after you call `Runtime.launch()`." - ) - dl_tensor = rt.executor.get_dl_tensor(self._tensor) - torch_view = torch.utils.dlpack.from_dlpack(dl_tensor) - return torch_view - def from_numpy(self, ndarray: np.ndarray, stream: int = 0) -> "Tensor": """ Copies the tensor from a host numpy array to the device. @@ -177,6 +120,37 @@ def from_numpy(self, ndarray: np.ndarray, stream: int = 0) -> "Tensor": rt.executor.tensor_write(self._tensor, ndarray, stream) return self + def to_dlpack(self): + """ + Returns a DLPack tensor that shares the same memory with the device tensor. + """ + rt = Runtime.get_runtime(self.runtime_id) + if not rt.launched(): + raise RuntimeError( + "Tensor is not allocated yet. `Tensor.to_dlpack()` is " + "usable only after you call `Runtime.launch()`." 
+ ) + return rt.executor.tensor_to_dlpack(self._tensor) + + @staticmethod + def from_dlpack(ext_tensor, runtime_id: int = -1) -> "Tensor": + """ + Copies the tensor from a DLPack tensor to the device. + """ + return Tensor(_Tensor(ext_tensor), runtime_id=runtime_id) + + def to_torch(self) -> torch.Tensor: + """ + Returns a torch tensor that shares the same memory with the device tensor. + """ + if _no_torch: + raise ImportError("torch is not available") + dl_capsule = self.to_dlpack() + torch_view = torch.utils.dlpack.from_dlpack(dl_capsule) + # Keep dl_capsule alive not to free the memory + torch_view.__ark_buffer__ = dl_capsule + return torch_view + @staticmethod def from_torch(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": """ @@ -188,10 +162,10 @@ def from_torch(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": raise ValueError("Torch tensor must be contiguous.") elif tensor.device.type == "cpu": raise ValueError("Torch tensor must be on a device.") - ark_dtype = DataType.from_torch(tensor.dtype) - dl_capsule = torch.utils.dlpack.to_dlpack(tensor) - ark_tensor = _Tensor(dl_capsule, ark_dtype.ctype()) - return Tensor(ark_tensor, runtime_id=runtime_id) + return Tensor.from_dlpack( + torch.utils.dlpack.to_dlpack(tensor), + runtime_id=runtime_id, + ) def copy( self, data: Union[np.ndarray, torch.Tensor], stream: int = 0 diff --git a/python/ark_py.cpp b/python/ark_py.cpp index 75788ba5..1bc4255d 100644 --- a/python/ark_py.cpp +++ b/python/ark_py.cpp @@ -7,7 +7,6 @@ namespace py = pybind11; -extern void register_plan_manager(py::module &m); extern void register_data_type(py::module &m); extern void register_dims(py::module &m); extern void register_error(py::module &m); @@ -23,7 +22,6 @@ extern void register_version(py::module &m); PYBIND11_MODULE(_ark_core, m) { m.doc() = "Bind ARK C++ APIs to Python"; - register_plan_manager(m); register_data_type(m); register_dims(m); register_error(m); diff --git a/python/executor_py.cpp b/python/executor_py.cpp index 126970d8..d90825e2 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -8,8 +8,10 @@ #include #include -#include -#include + +#include "gpu/gpu_memory.hpp" +#include "logging.hpp" + namespace py = pybind11; static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, @@ -42,37 +44,37 @@ static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, reinterpret_cast(stream), is_d2d); } -static DLDataType get_dl_dtype(const ark::DataType &ark_data_type) { - DLDataType dl_data_type; - dl_data_type.lanes = 1; - if (ark_data_type == ark::FP32) { - dl_data_type.code = kDLFloat; - dl_data_type.bits = 32; - } else if (ark_data_type == ark::FP16) { - dl_data_type.code = kDLFloat; - dl_data_type.bits = 16; - } else if (ark_data_type == ark::BF16) { - dl_data_type.code = kDLBfloat; - dl_data_type.bits = 16; - } else if (ark_data_type == ark::INT32) { - dl_data_type.code = kDLInt; - dl_data_type.bits = 32; - } else if (ark_data_type == ark::UINT32) { - dl_data_type.code = kDLUInt; - dl_data_type.bits = 32; - } else if (ark_data_type == ark::INT8) { - dl_data_type.code = kDLInt; - dl_data_type.bits = 8; - } else if (ark_data_type == ark::UINT8) { - dl_data_type.code = kDLUInt; - dl_data_type.bits = 8; - } else if (ark_data_type == ark::BYTE) { - dl_data_type.code = kDLUInt; - dl_data_type.bits = 8; +static DLDataType to_dl_dtype(const ark::DataType &ark_dtype) { + DLDataType dl_dtype; + dl_dtype.lanes = 1; + if (ark_dtype == ark::FP32) { + dl_dtype.code = kDLFloat; + dl_dtype.bits = 32; + } else if 
(ark_dtype == ark::FP16) { + dl_dtype.code = kDLFloat; + dl_dtype.bits = 16; + } else if (ark_dtype == ark::BF16) { + dl_dtype.code = kDLBfloat; + dl_dtype.bits = 16; + } else if (ark_dtype == ark::INT32) { + dl_dtype.code = kDLInt; + dl_dtype.bits = 32; + } else if (ark_dtype == ark::UINT32) { + dl_dtype.code = kDLUInt; + dl_dtype.bits = 32; + } else if (ark_dtype == ark::INT8) { + dl_dtype.code = kDLInt; + dl_dtype.bits = 8; + } else if (ark_dtype == ark::UINT8) { + dl_dtype.code = kDLUInt; + dl_dtype.bits = 8; + } else if (ark_dtype == ark::BYTE) { + dl_dtype.code = kDLUInt; + dl_dtype.bits = 8; } else { - throw std::runtime_error("unexpected error"); + ERR(ark::InternalError, "unexpected"); } - return dl_data_type; + return dl_dtype; } static DLDeviceType get_device_type() { @@ -85,66 +87,84 @@ static DLDeviceType get_device_type() { #endif } -static DLManagedTensor *to_dlpack(ark::Executor &exe, - const ark::Tensor &tensor) { - DLTensor dl_tensor; - dl_tensor.data = reinterpret_cast(exe.tensor_address(tensor)); - size_t offset_in_elements = - tensor.offsets().is_no_dim() ? 0 : tensor.offsets().vector()[0]; - dl_tensor.byte_offset = offset_in_elements * tensor.data_type().bytes(); - dl_tensor.device.device_type = get_device_type(); - dl_tensor.device.device_id = static_cast(exe.device_id()); - dl_tensor.ndim = static_cast(tensor.shape().ndims()); - dl_tensor.dtype = get_dl_dtype(tensor.data_type()); - - dl_tensor.shape = - tensor.shape().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; - dl_tensor.strides = - tensor.strides().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; - auto shape = tensor.shape(); - if (dl_tensor.shape) { - for (int i = 0; i < dl_tensor.ndim; ++i) { - dl_tensor.shape[i] = shape[i]; - } - } - if (dl_tensor.strides) { - dl_tensor.strides[dl_tensor.ndim - 1] = 1; - for (int i = dl_tensor.ndim - 2; i >= 0; --i) { - dl_tensor.strides[i] = - dl_tensor.shape[i + 1] * dl_tensor.strides[i + 1]; +namespace ark { + +class SharedTensor { + public: + SharedTensor(Executor &exe, const Tensor &tensor); + ~SharedTensor() = default; + + DLTensor dl_tensor() const; + + private: + std::shared_ptr buffer_; + void *data_; + int device_id_; + DataType dtype_; + std::shared_ptr> shape_; + std::shared_ptr> strides_; + std::shared_ptr> offsets_; +}; + +SharedTensor::SharedTensor(Executor &exe, const Tensor &tensor) { + buffer_ = exe.buffer(); + data_ = reinterpret_cast(exe.tensor_address(tensor)); + device_id_ = exe.device_id(); + dtype_ = tensor.data_type(); + shape_ = std::make_shared>(tensor.shape().vector()); + offsets_ = + std::make_shared>(tensor.offsets().vector()); + + strides_ = std::make_shared>(); + if (!shape_->empty()) { + int ndims = static_cast(shape_->size()); + strides_->resize(shape_->size()); + strides_->back() = 1; + auto tmp = tensor.strides().vector(); + for (int i = ndims - 2; i >= 0; --i) { + (*strides_)[i] = (*strides_)[i + 1] * tmp[i + 1]; } } - DLManagedTensor *dl_managed_tensor = new DLManagedTensor(); - dl_managed_tensor->dl_tensor = dl_tensor; - dl_managed_tensor->manager_ctx = nullptr; - dl_managed_tensor->deleter = [](DLManagedTensor *self) { - if (self->dl_tensor.shape) { - delete[] self->dl_tensor.shape; - self->dl_tensor.shape = nullptr; - } - if (self->dl_tensor.strides) { - delete[] self->dl_tensor.strides; - self->dl_tensor.strides = nullptr; - } - }; - return dl_managed_tensor; } -void free_capsule(PyObject *capsule) { - const char *name = PyCapsule_GetName(capsule); - auto *dl_managed_tensor = - static_cast(PyCapsule_GetPointer(capsule, 
name)); - if (dl_managed_tensor) { - dl_managed_tensor->deleter(dl_managed_tensor); - dl_managed_tensor = nullptr; - } +DLTensor SharedTensor::dl_tensor() const { + DLTensor dl_tensor; + dl_tensor.data = data_; + size_t offset_in_elements = offsets_->empty() ? 0 : offsets_->at(0); + dl_tensor.byte_offset = offset_in_elements * dtype_.bytes(); + dl_tensor.device.device_type = get_device_type(); + dl_tensor.device.device_id = device_id_; + dl_tensor.ndim = static_cast(shape_->size()); + dl_tensor.dtype = to_dl_dtype(dtype_); + dl_tensor.shape = shape_->data(); + dl_tensor.strides = strides_->data(); + return dl_tensor; } -py::capsule to_dlpack_capsule(ark::Executor &self, const ark::Tensor &tensor) { - DLManagedTensor *dl_managed_tensor = to_dlpack(self, tensor); +} // namespace ark + +static py::capsule tensor_to_dlpack(ark::Executor &self, const ark::Tensor &tensor) { + auto shared_tensor = new ark::SharedTensor(self, tensor); + DLManagedTensor *dl_managed_tensor = new DLManagedTensor(); + dl_managed_tensor->dl_tensor = shared_tensor->dl_tensor(); + dl_managed_tensor->manager_ctx = shared_tensor; + dl_managed_tensor->deleter = [](DLManagedTensor *self) { + if (self->manager_ctx) { + delete static_cast(self->manager_ctx); + self->manager_ctx = nullptr; + } + }; const char *capsule_name = "dltensor"; PyObject *dl_capsule = PyCapsule_New(static_cast(dl_managed_tensor), - capsule_name, free_capsule); + capsule_name, [](PyObject *capsule) { + const char *name = PyCapsule_GetName(capsule); + auto *dl_managed_tensor = static_cast( + PyCapsule_GetPointer(capsule, name)); + if (dl_managed_tensor) { + dl_managed_tensor->deleter(dl_managed_tensor); + dl_managed_tensor = nullptr; + } + }); return py::reinterpret_steal(dl_capsule); } @@ -191,5 +211,5 @@ void register_executor(py::module &m) { size_t, uintptr_t, bool>(&tensor_write), py::arg("tensor"), py::arg("address"), py::arg("bytes"), py::arg("stream"), py::arg("is_d2d")) - .def("get_dl_tensor", &to_dlpack_capsule); + .def("tensor_to_dlpack", &tensor_to_dlpack); } diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp index 16eb0342..e7f06592 100644 --- a/python/tensor_py.cpp +++ b/python/tensor_py.cpp @@ -8,6 +8,8 @@ #include +#include "logging.hpp" + namespace py = pybind11; struct DLTensorMetadata { @@ -40,12 +42,37 @@ static DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) { return metadata; } +static ark::DataType from_dl_dtype(const DLDataType &dl_dtype) { + if (dl_dtype.lanes != 1) { + ERR(ark::UnsupportedError, "unsupported data type"); + } + ark::DataType ark_dtype; + if (dl_dtype.code == kDLFloat && dl_dtype.bits == 32) { + ark_dtype = ark::FP32; + } else if (dl_dtype.code == kDLFloat && dl_dtype.bits == 16) { + ark_dtype = ark::FP16; + } else if (dl_dtype.code == kDLBfloat && dl_dtype.bits == 16) { + ark_dtype = ark::BF16; + } else if (dl_dtype.code == kDLInt && dl_dtype.bits == 32) { + ark_dtype = ark::INT32; + } else if (dl_dtype.code == kDLUInt && dl_dtype.bits == 32) { + ark_dtype = ark::UINT32; + } else if (dl_dtype.code == kDLInt && dl_dtype.bits == 8) { + ark_dtype = ark::INT8; + } else if (dl_dtype.code == kDLUInt && dl_dtype.bits == 8) { + ark_dtype = ark::UINT8; + } else { + ERR(ark::UnsupportedError, "unsupported data type"); + } + return ark_dtype; +} + void register_tensor(py::module& m) { py::class_(m, "_Tensor") - .def(py::init([](py::capsule capsule, const ark::DataType& dtype) { + .def(py::init([](py::capsule capsule) { DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule; if (!dl_tensor) { - 
throw std::runtime_error( + ERR(ark::InvalidUsageError, "Capsule does not contain a DLManagedTensor"); } DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); @@ -53,7 +80,7 @@ void register_tensor(py::module& m) { void* data_ptr = metadata.data_ptr; auto shape = metadata.shape; - return new ark::Tensor(data_ptr, device_id, shape, dtype); + return ark::Tensor(data_ptr, device_id, shape, from_dl_dtype(metadata.dtype)); })) .def("id", &ark::Tensor::id) .def("shape", &ark::Tensor::shape, py::return_value_policy::reference) diff --git a/python/unittest/test.py b/python/unittest/test.py index 238b16fb..d56932b8 100644 --- a/python/unittest/test.py +++ b/python/unittest/test.py @@ -10,4 +10,4 @@ from test_error import * from test_model import * from test_runtime import * -from test_conversion import * +from test_tensor import * diff --git a/python/unittest/test_runtime.py b/python/unittest/test_runtime.py index 8c00b51f..c3d15d1b 100644 --- a/python/unittest/test_runtime.py +++ b/python/unittest/test_runtime.py @@ -20,99 +20,99 @@ def test_runtime_relaunch(): assert rt.launched() == True -def test_multiple_runtime_launch(): - ark.init() - num_runtimes = 5 - for i in range(num_runtimes): - rt = ark.Runtime.get_runtime(i) - assert rt.launched() == False - rt.launch(plan=empty_plan, device_id=i) - assert rt.launched() == True - for i in range(num_runtimes): - rt = ark.Runtime.get_runtime(i) - assert rt.launched() == True - ark.Runtime.delete_all_runtimes() - - -def test_stop_runtime(): - ark.init() - rt1 = ark.Runtime.get_runtime(1) - rt1.launch(plan=empty_plan, device_id=1) - rt2 = ark.Runtime.get_runtime(2) - rt2.launch(plan=empty_plan, device_id=2) - rt1.stop() - rt1.reset() - assert rt1.state == ark.Runtime.State.Init - assert rt2.state == ark.Runtime.State.LaunchedNotRunning - ark.Runtime.delete_all_runtimes() - - -def test_reset_runtime(): - ark.init() - rt1 = ark.Runtime.get_runtime(0) - rt1.launch(plan=empty_plan, device_id=1) - rt2 = ark.Runtime.get_runtime(1) - rt2.launch(plan=empty_plan, device_id=2) - rt1.reset() - assert rt1.launched() == False - assert rt2.launched() == True - rt1.launch(plan=empty_plan) - assert rt1.launched() == True - ark.Runtime.delete_all_runtimes() - - -def test_multiple_runtimes_complex(): - ark.init() - num_runtimes = 3 - runtime_list = [ark.Runtime.get_runtime(i) for i in range(num_runtimes)] - default_runtime = ark.Runtime.get_runtime() - runtime_list.append(default_runtime) - for i, rt in enumerate(runtime_list): - rt.launch(plan=empty_plan, device_id=i) - assert rt.launched() == True - runtime_list[0].stop() - assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning - for rt in runtime_list[1:]: - assert rt.launched() == True - runtime_list[1].reset() - assert runtime_list[1].state == ark.Runtime.State.Init - assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning - assert runtime_list[2].state == ark.Runtime.State.LaunchedNotRunning - runtime_list[1].launch(plan=empty_plan, device_id=1) - for rt in runtime_list: - assert rt.launched() == True - ark.Runtime.delete_all_runtimes() - - -def test_runtime_state_after_reset(): - ark.init() - rt = ark.Runtime.get_runtime() - rt.launch(plan=empty_plan) - rt.reset() - assert rt.launched() == False - assert rt.running() == False - ark.Runtime.delete_all_runtimes() - - -def test_see_runtime_statuses(): - ark.init() - num_runtimes = 3 - runtimes = [ark.Runtime.get_runtime(i) for i in range(num_runtimes)] - runtime_statuses = ark.Runtime.see_runtime_statuses() - assert 
len(runtime_statuses) == num_runtimes - for i in range(num_runtimes): - assert i in runtime_statuses - for i, rt in enumerate(runtimes): - assert runtime_statuses[i] == rt - ark.Runtime.delete_all_runtimes() - - -def test_multiple_runtimes_init(): - ark.init() - runtimes = [ark.Runtime.get_runtime(i) for i in range(3)] - for rt in runtimes: - assert rt.state == ark.Runtime.State.Init - ark.init() - runtimes = ark.Runtime.see_runtime_statuses() - assert len(runtimes) == 0 - ark.Runtime.delete_all_runtimes() +# def test_multiple_runtime_launch(): +# ark.init() +# num_runtimes = 5 +# for i in range(num_runtimes): +# rt = ark.Runtime.get_runtime(i) +# assert rt.launched() == False +# rt.launch(plan=empty_plan, device_id=i) +# assert rt.launched() == True +# for i in range(num_runtimes): +# rt = ark.Runtime.get_runtime(i) +# assert rt.launched() == True +# ark.Runtime.delete_all_runtimes() + + +# def test_stop_runtime(): +# ark.init() +# rt1 = ark.Runtime.get_runtime(1) +# rt1.launch(plan=empty_plan, device_id=1) +# rt2 = ark.Runtime.get_runtime(2) +# rt2.launch(plan=empty_plan, device_id=2) +# rt1.stop() +# rt1.reset() +# assert rt1.state == ark.Runtime.State.Init +# assert rt2.state == ark.Runtime.State.LaunchedNotRunning +# ark.Runtime.delete_all_runtimes() + + +# def test_reset_runtime(): +# ark.init() +# rt1 = ark.Runtime.get_runtime(0) +# rt1.launch(plan=empty_plan, device_id=1) +# rt2 = ark.Runtime.get_runtime(1) +# rt2.launch(plan=empty_plan, device_id=2) +# rt1.reset() +# assert rt1.launched() == False +# assert rt2.launched() == True +# rt1.launch(plan=empty_plan) +# assert rt1.launched() == True +# ark.Runtime.delete_all_runtimes() + + +# def test_multiple_runtimes_complex(): +# ark.init() +# num_runtimes = 3 +# runtime_list = [ark.Runtime.get_runtime(i) for i in range(num_runtimes)] +# default_runtime = ark.Runtime.get_runtime() +# runtime_list.append(default_runtime) +# for i, rt in enumerate(runtime_list): +# rt.launch(plan=empty_plan, device_id=i) +# assert rt.launched() == True +# runtime_list[0].stop() +# assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning +# for rt in runtime_list[1:]: +# assert rt.launched() == True +# runtime_list[1].reset() +# assert runtime_list[1].state == ark.Runtime.State.Init +# assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning +# assert runtime_list[2].state == ark.Runtime.State.LaunchedNotRunning +# runtime_list[1].launch(plan=empty_plan, device_id=1) +# for rt in runtime_list: +# assert rt.launched() == True +# ark.Runtime.delete_all_runtimes() + + +# def test_runtime_state_after_reset(): +# ark.init() +# rt = ark.Runtime.get_runtime() +# rt.launch(plan=empty_plan) +# rt.reset() +# assert rt.launched() == False +# assert rt.running() == False +# ark.Runtime.delete_all_runtimes() + + +# def test_see_runtime_statuses(): +# ark.init() +# num_runtimes = 3 +# runtimes = [ark.Runtime.get_runtime(i) for i in range(num_runtimes)] +# runtime_statuses = ark.Runtime.see_runtime_statuses() +# assert len(runtime_statuses) == num_runtimes +# for i in range(num_runtimes): +# assert i in runtime_statuses +# for i, rt in enumerate(runtimes): +# assert runtime_statuses[i] == rt +# ark.Runtime.delete_all_runtimes() + + +# def test_multiple_runtimes_init(): +# ark.init() +# runtimes = [ark.Runtime.get_runtime(i) for i in range(3)] +# for rt in runtimes: +# assert rt.state == ark.Runtime.State.Init +# ark.init() +# runtimes = ark.Runtime.see_runtime_statuses() +# assert len(runtimes) == 0 +# ark.Runtime.delete_all_runtimes() diff 
--git a/python/unittest/test_tensor.py b/python/unittest/test_tensor.py
new file mode 100644
index 00000000..1acad43e
--- /dev/null
+++ b/python/unittest/test_tensor.py
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from unittest_common import pytest_ark
+import ark
+
+
+@pytest_ark(need_torch=True)
+def test_tensor_torch():
+    import torch
+
+    ones = torch.ones(2, 1024, device=torch.device("cuda:0"))
+
+    t = ark.Tensor.from_torch(ones)
+    t = ark.mul(t, 5)
+
+    with ark.Runtime() as rt:
+        rt.launch()
+        rt.run()
+
+    x = t.to_torch()
+
+    assert torch.allclose(x, ones * 5)
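Usage sketch for the DLPack path added by this patch: a minimal, hedged example of the intended Python flow, assuming a CUDA build with a GPU visible as cuda:0. It mirrors python/unittest/test_tensor.py; the explicit ark.init() call is an assumption for standalone use (the unit test gets the equivalent setup from the pytest_ark fixture), and variable names are illustrative only.

import torch
import ark

ark.init()  # assumed here; the unit test relies on the pytest_ark fixture instead

# Wrap an existing torch GPU tensor without copying. from_torch() goes through
# torch.utils.dlpack.to_dlpack() and the new _Tensor(capsule) binding.
x = torch.ones(2, 1024, device="cuda:0")
t = ark.Tensor.from_torch(x)
y = ark.mul(t, 5)

with ark.Runtime() as rt:
    rt.launch()
    rt.run()

# to_torch() is now a zero-copy view: it calls to_dlpack(), which wraps the
# executor's buffer via the tensor_to_dlpack() binding, and pins the capsule
# on the view (__ark_buffer__) so the underlying memory stays alive.
out = y.to_torch()
assert torch.allclose(out, x * 5)

# The raw DLPack capsule is also available for other consumers; each call
# produces a fresh capsule that can be consumed once.
capsule = y.to_dlpack()
alias = torch.utils.dlpack.from_dlpack(capsule)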