Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Aug 11, 2024
1 parent c0cbf19 commit 8583d1b
Show file tree
Hide file tree
Showing 17 changed files with 349 additions and 336 deletions.
4 changes: 4 additions & 0 deletions ark/api/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ class Executor::Impl {

/// Return the raw stream member reinterpreted as a `Stream` handle.
Stream stream() const
{
    return reinterpret_cast<Stream>(stream_raw_);
}

/// Return a copy of the shared pointer held in `buffer_`.
std::shared_ptr<GpuMemory> buffer() const
{
    return buffer_;
}

/// Return the stored plan JSON rendered through `dump_pretty()`.
std::string plan() const
{
    return plan_json_.dump_pretty();
}

void compile();
Expand Down Expand Up @@ -934,6 +936,8 @@ int Executor::device_id() const { return impl_->device_id(); }

/// Return the stream by delegating to `impl_`.
Stream Executor::stream() const
{
    return impl_->stream();
}

/// Return the buffer by delegating to `impl_`.
std::shared_ptr<GpuMemory> Executor::buffer() const
{
    return impl_->buffer();
}

/// Return the plan string by delegating to `impl_`.
std::string Executor::plan() const
{
    return impl_->plan();
}

/// Compile by delegating to `impl_`.
void Executor::compile()
{
    impl_->compile();
}
Expand Down
5 changes: 5 additions & 0 deletions ark/include/ark/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace ark {

using Stream = void *;

class GpuMemory;

/// Convenience class for executing a model.
class Executor {
public:
Expand All @@ -31,6 +33,9 @@ class Executor {
/// Return the stream of the executor.
Stream stream() const;

/// Return the buffer of the executor.
std::shared_ptr<GpuMemory> buffer() const;

/// Return the plan string.
std::string plan() const;

Expand Down
13 changes: 6 additions & 7 deletions examples/tutorial/planner_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,13 @@ def perf():

shape = (32, 2048, 2048)

# input = torch.randn(*shape).to("cuda:0")
input = ark.tensor(shape)
input = torch.randn(*shape).to("cuda:0")

output = Softmax()(input)
output = Softmax()(ark.Tensor.from_torch(input))

# if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5):
# print("Correct result")
# else:
# print("Incorrect result")
if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5):
print("Correct result")
else:
print("Incorrect result")

print(f"Performance: {(perf() * 1e3):.3f} ms/iter")
13 changes: 13 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,16 @@ pybind11_add_module(ark_py ${BIND_SOURCES})
# Emit the ark_py target under the output name _ark_core.
set_target_properties(ark_py PROPERTIES OUTPUT_NAME _ark_core)
target_link_libraries(ark_py PRIVATE ark_static)
target_include_directories(ark_py SYSTEM PRIVATE ${DLPACK_INCLUDE_DIRS})
# Add the sibling ../ark directory to the include path of the bindings.
target_include_directories(ark_py PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../ark)

# When ARK_USE_CUDA is set, add the CUDA toolkit headers as system includes.
if(ARK_USE_CUDA)
target_include_directories(ark_py SYSTEM PRIVATE
${CUDAToolkit_INCLUDE_DIRS}
)
endif()

# When ARK_USE_ROCM is set, add the ROCm headers as system includes.
# NOTE(review): the path is hard-coded to /opt/rocm/include; a ROCm install in
# another location would presumably need a configurable path here -- confirm.
if(ARK_USE_ROCM)
target_include_directories(ark_py SYSTEM PRIVATE
/opt/rocm/include
)
endif()
5 changes: 4 additions & 1 deletion python/ark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os

if os.environ.get("ARK_ROOT", None) is None:
os.environ["ARK_ROOT"] = os.path.abspath(os.path.dirname(__file__))

from . import _ark_core
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import _ark_core
from .model import Model


Expand Down
2 changes: 1 addition & 1 deletion python/ark/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import numpy
from . import _ark_core
import _ark_core

try:
import torch
Expand Down
2 changes: 1 addition & 1 deletion python/ark/init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from . import _ark_core
import _ark_core
from .model import Model
from .runtime import _RuntimeState

Expand Down
2 changes: 1 addition & 1 deletion python/ark/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

from typing import NewType
from ._ark_core import _Model
from _ark_core import _Model

_ModelState = NewType("_ModelState", None)

Expand Down
4 changes: 2 additions & 2 deletions python/ark/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from typing import Any, Dict, List, Union
from .tensor import Tensor, Parameter
from .runtime import Runtime, DefaultPlanner
from .runtime import Runtime, Planner
from .ops import tensor
from .data_type import DataType

Expand Down Expand Up @@ -183,7 +183,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
self.built_forward = True

with Runtime.get_runtime() as rt:
rt.launch(plan=DefaultPlanner().plan())
rt.launch(plan=Planner().plan())
for tns, arg in zip(self.forward_input_tensor_args, args):
tns.copy(arg)
for key, value in self.forward_input_tensor_kwargs.items():
Expand Down
Loading

0 comments on commit 8583d1b

Please sign in to comment.