
Commit

Minor changes
chhwang committed Aug 11, 2024
1 parent 598cb78 commit 28ce027
Showing 9 changed files with 40 additions and 52 deletions.
2 changes: 1 addition & 1 deletion ark/api/executor_test.cpp
@@ -88,7 +88,7 @@ ark::unittest::State test_executor_tensor_read_write(ark::Dims shape,
ark::DefaultExecutor executor(m, 0);
executor.compile();
executor.launch();
UNITTEST_GT(executor.tensor_address(tensor), 0);
UNITTEST_NE(executor.tensor_address(tensor), nullptr);

// Copy data from CPU array to ARK tensor
executor.tensor_write(tensor, host_data.data(),
1 change: 1 addition & 0 deletions examples/tutorial/model_test_tutorial.py
@@ -9,6 +9,7 @@
# Set random seed for reproducibility.
torch.manual_seed(42)


# Let's first define a linear layer using ARK.
class ARKLinear(ark.Module):
def __init__(self, weight):
4 changes: 2 additions & 2 deletions python/ark/init.py
@@ -6,10 +6,10 @@
from .runtime import _RuntimeState


def init(keep_runtime: bool = False):
def init():
"""Initializes ARK."""
Model.reset()
if not keep_runtime and _RuntimeState.runtime is not None:
if _RuntimeState.runtime is not None:
del _RuntimeState.runtime
_RuntimeState.runtime = None
_ark_core.init()
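
init() no longer accepts a keep_runtime flag; it always clears the model graph and, if one exists, the global runtime. A minimal sketch of the resulting call pattern, mirroring the test updates later in this commit (names are illustrative):

    import ark

    ark.init()           # clears the model graph and drops any existing runtime
    # ... build and run a model ...

    # To rebuild the model graph while keeping the runtime alive, reset only
    # the model instead of calling init() again:
    ark.Model.reset()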
41 changes: 18 additions & 23 deletions python/ark/module.py
@@ -5,10 +5,9 @@
import numpy as np
from typing import Any, Dict, Union
from .tensor import Tensor, Parameter
from .runtime import Runtime, Planner
from .runtime import Runtime
from .init import init
from .ops import tensor
from .data_type import DataType
from .model import Model

try:
import torch
@@ -78,6 +77,7 @@ def load_state_dict(
self,
state_dict: Dict[str, Union[np.ndarray, torch.Tensor]],
prefix: str = "",
stream: int = 0,
):
"""
Loads a model from a state_dict and copy the parameters to the device GPU.
@@ -91,15 +91,18 @@ def load_state_dict(
data = state_dict.get(name, None)
if data is None:
continue
param.copy(data)
param.copy(data, stream=stream)
all_keys.remove(name)
if all_keys:
logging.warning(
f"{len(all_keys)} unused parameter(s) in state_dict"
)

def state_dict(
self, prefix: str = "", mode: str = "numpy"
self,
prefix: str = "",
mode: str = "numpy",
stream: int = 0,
) -> Dict[str, Union[np.ndarray, torch.Tensor]]:
"""
Copies the parameters from the device GPU to the host and saves the
@@ -108,11 +111,13 @@ def state_dict(
"""
if mode == "numpy":
return {
k: v.to_numpy() for k, v in self.params_dict(prefix).items()
k: v.to_numpy(stream=stream)
for k, v in self.params_dict(prefix).items()
}
elif mode == "torch":
return {
k: v.to_torch() for k, v in self.params_dict(prefix).items()
k: v.to_torch(stream=stream)
for k, v in self.params_dict(prefix).items()
}
raise ValueError(f"Unsupported mode: {mode}")
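
Both load_state_dict() and state_dict() now take a stream argument that is forwarded to the per-parameter copy / to_numpy / to_torch calls. A hedged usage sketch (the module and state-dict names are illustrative; 0 denotes the default stream):

    # Copy parameters from device to host, as numpy arrays or torch tensors.
    sd = ark_linear.state_dict(mode="torch", stream=0)

    # Copy them back onto the device, forwarding the same stream.
    ark_linear.load_state_dict(sd, stream=0)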

@@ -127,17 +132,7 @@ def initialize(self):
module.initialize()


def _recursive_ark_to_torch(object):
if isinstance(object, Tensor):
return object.to_torch()
if isinstance(object, dict):
return {k: _recursive_ark_to_torch(v) for k, v in object.items()}
if isinstance(object, list):
return [_recursive_ark_to_torch(v) for v in object]
return object


class _ARKFunction(torch.autograd.Function):
class _Function(torch.autograd.Function):
"""
Facilitates the integration of ARK modules with PyTorch's
autograd system by defining custom forward and backward passes that
@@ -150,7 +145,7 @@ def forward(ctx, ark_module, *args, **kwargs):
Returns a PyTorch tensor that is the result
of the forward pass of the ARK module.
"""
init(keep_runtime=True)
Model.reset()
ctx.ark_module = ark_module
input_args, input_kwargs = [], {}
input_requires_grad = 0
@@ -184,12 +179,12 @@ def backward(ctx, *grad_outputs):
and parameters using the ARK module backwards pass, and updates the gradients of the corresponding
PyTorch parameters.
"""
init(keep_runtime=True)
Model.reset()
ark_grad_outputs = [Tensor.from_torch(grad) for grad in grad_outputs]
grads = ctx.ark_module.backward(*ark_grad_outputs)
grad_inputs, grad_weights = (
grads[:ctx.num_inp_grad],
grads[ctx.num_inp_grad:],
grads[: ctx.num_inp_grad],
grads[ctx.num_inp_grad :],
)
params_dict = ctx.ark_module.params_dict()
rt = Runtime.get_runtime()
@@ -214,4 +209,4 @@ def __init__(self, ark_module):
self.ark_module = ark_module

def forward(self, *args, **kwargs):
return _ARKFunction.apply(self.ark_module, *args, **kwargs)
return _Function.apply(self.ark_module, *args, **kwargs)
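
_ARKFunction is renamed to _Function, and both its forward and backward now call Model.reset() instead of init(keep_runtime=True), so only the model graph is cleared between autograd traces while the runtime survives. A hedged sketch of how the torch-side wrapper at the end of this file is typically driven; the wrapper's class name is not visible in this diff, so _TorchWrapper below is a placeholder:

    import torch
    from ark.module import _Function  # private helper shown in this file

    class _TorchWrapper(torch.nn.Module):  # placeholder name
        def __init__(self, ark_module):
            super().__init__()
            self.ark_module = ark_module

        def forward(self, *args, **kwargs):
            return _Function.apply(self.ark_module, *args, **kwargs)

    # wrapped = _TorchWrapper(some_ark_module)
    # y = wrapped(x)        # forward pass traced through ARK
    # y.sum().backward()    # backward dispatched to some_ark_module.backward()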
8 changes: 2 additions & 6 deletions python/ark/runtime.py
@@ -3,7 +3,6 @@

import logging
from enum import Enum
from typing import Dict, List

from _ark_core import _Executor
from .planner import Planner, Plan
@@ -169,9 +168,9 @@ def stop(self) -> float:
self.state = Runtime.State.LaunchedNotRunning
return elapsed

def reset(self, delete=False, persist=False):
def reset(self, persist=False):
"""
Reset the runtime. If delete is True, delete the runtime.
Reset the runtime.
"""
if self.launched():
self.stop()
@@ -182,6 +181,3 @@ def reset(self, delete=False, persist=False):
self.executor.destroy()
self.executor = None
self.state = Runtime.State.Init
if delete:
del _RuntimeState.runtime
_RuntimeState.runtime = None
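
reset() drops its delete parameter; removing the global runtime is now handled by init() alone. persist=True still skips destroying the executor so it can be relaunched later. A brief sketch of the two modes, following the updated tests (model setup elided):

    import ark

    runtime = ark.Runtime.get_runtime()
    runtime.launch()
    runtime.run()

    runtime.reset(persist=True)   # stop, but keep the executor for a later launch()
    ark.Model.reset()             # clear the model graph before building a new one

    runtime.reset()               # full reset: stops and destroys the executor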
8 changes: 4 additions & 4 deletions python/ark/tensor.py
@@ -217,8 +217,10 @@ class Parameter(Tensor, torch.nn.Parameter):
"""
A tensor as a parameter.
"""

def __init__(
self, tensor: Union[_Tensor, "torch.nn.Parameter"],
self,
tensor: Union[_Tensor, "torch.nn.Parameter"],
):
"""
Initializes a new instance of the Parameter class.
@@ -237,9 +239,7 @@ def __init__(
core_tensor = tensor
self.torch_param = None
self.staged_tensor = None
Tensor.__init__(
self, core_tensor, requires_grad=False
)
Tensor.__init__(self, core_tensor, requires_grad=False)
else:
raise TypeError(
"tensor must be an ARK tensor or a torch.nn.Parameter"
10 changes: 2 additions & 8 deletions python/ark/torch_mock.py
@@ -29,21 +29,15 @@ class ubyte: ...
class Tensor: ...



class nn:


class Module: ...


class Parameter: ...
class Parameter: ...


class autograd:


class Function:

class Function:

def apply(self, *args, **kwargs): ...

9 changes: 7 additions & 2 deletions python/executor_py.cpp
@@ -182,8 +182,13 @@ void register_executor(py::module &m) {
.def("barrier", &ark::Executor::barrier)
.def("destroy", &ark::Executor::destroy)
.def("destroyed", &ark::Executor::destroyed)
.def("tensor_address", &ark::Executor::tensor_address,
py::arg("tensor"))
.def(
"tensor_address",
[](ark::Executor *self, const ark::Tensor &tensor) {
return reinterpret_cast<uintptr_t>(
self->tensor_address(tensor));
},
py::arg("tensor"))
.def("tensor_read",
py::overload_cast<ark::Executor *, const ark::Tensor &, py::buffer,
uintptr_t>(&tensor_read),
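
The tensor_address binding now goes through a lambda that casts the returned pointer to uintptr_t, so Python receives a plain integer address; this matches the C++ test change above, where the address is compared against nullptr. A hedged Python-side sketch (executor and tensor setup elided, as in executor_test.cpp):

    # `executor` and `tensor` are assumed to be an already compiled and
    # launched ARK executor / tensor pair.
    addr = executor.tensor_address(tensor)  # an int (uintptr_t), not a pointer object
    assert isinstance(addr, int) and addr != 0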
9 changes: 3 additions & 6 deletions python/unittest/test_runtime.py
@@ -5,9 +5,6 @@
import numpy as np


empty_plan = ark.Plan(None)


def test_runtime_relaunch():
ark.init()
with ark.Runtime.get_runtime() as rt:
@@ -39,7 +36,7 @@ def test_add_plans():
output_tensor_host, input_tensor_host + other_tensor_host
)
runtime.reset(persist=True)
ark.init(keep_runtime=True)
ark.Model.reset()
prev_output = output_tensor
new_tensor = ark.tensor([M, N], ark.fp16)
final_output = ark.add(prev_output, new_tensor)
@@ -53,6 +50,7 @@ def test_add_plans():
)
runtime.reset()


def test_reuse_plans():
ark.init()
M, N = 64, 64
@@ -71,12 +69,11 @@
output_tensor_host, input_tensor_host + other_tensor_host
)
runtime.reset(persist=True)
ark.init(keep_runtime=True)
ark.Model.reset()
runtime.launch()
runtime.run()
output_tensor_host = output_tensor.to_numpy()
np.testing.assert_allclose(
output_tensor_host, input_tensor_host + other_tensor_host
)
runtime.reset()
