Use TensorMeta to check if inputs and outputs are memory planned (#6114)
Use TensorMeta to check if inputs and outputs are memory planned (#5565)

Summary:
Pull Request resolved: #5565

Swap to using MethodMeta so we can be finer-grained about this check.

Reviewed By: dbort

Differential Revision: D62983475

fbshipit-source-id: c4599c5ecad0409cd8b2670464c4e9e8809b49ad
(cherry picked from commit df72b8c)

Co-authored-by: Jacob Szwejbka <[email protected]>
pytorchbot and JacobSzwejbka authored Oct 11, 2024
1 parent e0fcdd4 commit 67c959a
Showing 4 changed files with 75 additions and 64 deletions.
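To make the intent of the change concrete, here is a minimal caller-side sketch (not part of the diff) of what the per-tensor check enables. It assumes the executorch::runtime C++ API; the function and variable names (maybe_set_output_buffer, method, output_idx, buffer, size) are illustrative only.

#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>

using executorch::runtime::Error;
using executorch::runtime::Method;
using executorch::runtime::Result;
using executorch::runtime::TensorInfo;

// Override an output's data pointer only when that specific output is not
// memory planned, instead of gating on a method-wide flag.
Error maybe_set_output_buffer(
    Method& method, size_t output_idx, void* buffer, size_t size) {
  Result<TensorInfo> out_meta =
      method.method_meta().output_tensor_meta(output_idx);
  if (!out_meta.ok()) {
    return out_meta.error();
  }
  if (out_meta->is_memory_planned()) {
    // Memory-planned (or constant) outputs keep their existing buffer; with
    // this change, set_output_data_ptr() reports InvalidState for them.
    return Error::Ok;
  }
  return method.set_output_data_ptr(buffer, size, output_idx);
}

Previously this decision hinged on a single pre_allocated_output_ flag that was true if any output had pre-allocated memory; the MethodMeta route makes the check per output.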
53 changes: 50 additions & 3 deletions extension/pybindings/test/make_test.py
@@ -7,10 +7,11 @@
# pyre-unsafe

import unittest
from typing import Any, Callable, Tuple
from typing import Any, Callable, Optional, Tuple

import torch
from executorch.exir import ExecutorchProgramManager, to_edge
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
from executorch.exir.passes import MemoryPlanningPass
from torch.export import export


@@ -75,8 +76,25 @@ def get_methods_to_export(self):
def get_inputs(self):
return (torch.ones(2, 2),)

class ModuleAddConstReturn(torch.nn.Module):
"""The module to serialize and execute."""

def __init__(self):
super(ModuleAddConstReturn, self).__init__()
self.state = torch.ones(2, 2)

def forward(self, x):
return x + self.state, self.state

def get_methods_to_export(self):
return ("forward",)

def get_inputs(self):
return (torch.ones(2, 2),)

def create_program(
eager_module: torch.nn.Module,
et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
"""Returns an executorch program based on ModuleAdd, along with inputs."""

@@ -103,7 +121,7 @@ def forward(self, *args, **kwargs):
)
exported_methods[method_name] = export(wrapped_mod, method_input)

exec_prog = to_edge(exported_methods).to_executorch()
exec_prog = to_edge(exported_methods).to_executorch(config=et_config)

# Create the ExecuTorch program from the graph.
exec_prog.dump_executorch_program(verbose=True)
@@ -251,12 +269,41 @@ def test_quantized_ops(tester):
expected = example_inputs[0] + example_inputs[1]
tester.assertEqual(str(expected), str(executorch_output))

def test_constant_output_not_memory_planned(tester):
# Create an ExecuTorch program from ModuleAdd.
exported_program, inputs = create_program(
ModuleAddConstReturn(),
et_config=ExecutorchBackendConfig(
memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False)
),
)

exported_program.dump_executorch_program(verbose=True)

# Use pybindings to load and execute the program.
executorch_module = load_fn(exported_program.buffer)
# Invoke the callable on executorch_module instead of calling module.forward.
# Use only one input to test this case.
executorch_output = executorch_module((torch.ones(2, 2),))
print(executorch_output)

# The test module adds the input to torch.ones(2,2), so its output should be the same
# as adding them directly.
expected = torch.ones(2, 2) + torch.ones(2, 2)
tester.assertEqual(str(expected), str(executorch_output[0]))

# The test module returns the state. Check that its value is correct.
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

######### RUN TEST CASES #########

test_e2e(tester)
test_multiple_entry(tester)
test_output_lifespan(tester)
test_module_callable(tester)
test_module_single_input(tester)
test_stderr_redirect(tester)
test_quantized_ops(tester)
test_constant_output_not_memory_planned(tester)

return wrapper
61 changes: 14 additions & 47 deletions runtime/executor/method.cpp
@@ -744,40 +744,6 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
}
}

// Validate input values and get tensor pre-allocation info.
pre_allocated_input_ = false;
for (int i = 0; i < inputs_size(); i++) {
// get_input() will panic if the index is invalid, so do this manually.
size_t index = get_input_index(i);
ET_CHECK_OR_RETURN_ERROR(
index < n_value_,
InvalidProgram,
"Input index %zu >= %zu",
index,
n_value_);
const EValue& input = values_[index];
if (input.isTensor()) {
pre_allocated_input_ |= input.toTensor().const_data_ptr() != nullptr;
}
}

// Validate output values and get tensor pre-allocation info.
pre_allocated_output_ = false;
for (int i = 0; i < outputs_size(); i++) {
// get_output() will panic if the index is invalid, so do this manually.
size_t index = get_output_index(i);
ET_CHECK_OR_RETURN_ERROR(
index < n_value_,
InvalidProgram,
"output index %zu >= %zu",
index,
n_value_);
const EValue& output = values_[index];
if (output.isTensor()) {
pre_allocated_output_ |= output.toTensor().const_data_ptr() != nullptr;
}
}

step_state_ = StepState{0, 0};

init_state_ = InitializationState::Initialized;
@@ -841,7 +807,8 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) {
input_idx,
static_cast<uint32_t>(err));
Error error;
if (pre_allocated_input_) {
auto tensor_meta = this->method_meta().input_tensor_meta(input_idx);
if (tensor_meta->is_memory_planned()) {
error = internal::copy_tensor_data(t_dst, t_src);
} else {
error = internal::share_tensor_data(t_dst, t_src);
@@ -950,21 +917,11 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
InvalidState,
"Outputs can not be retrieved until method has been initialized.");

// ET_CHECK_OR_RETURN_ERROR(
// !pre_allocated_output_,
// InvalidState,
// "Overriding output data pointer allocated by memory plan is not
// allowed.");
// TODO(T188740925): for now, return error without logs.
if (pre_allocated_output_) {
return Error::InvalidState;
}

// Check the args
ET_CHECK_OR_RETURN_ERROR(
output_idx <= outputs_size(),
output_idx < outputs_size(),
InvalidArgument,
"output_idx: %zu num_outputs: %zu",
"output_idx: %zu > num_outputs: %zu",
output_idx,
outputs_size());

@@ -975,6 +932,16 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) {
"output type: %zu is not tensor",
(size_t)output.tag);

auto tensor_meta = this->method_meta().output_tensor_meta(output_idx);
if (tensor_meta->is_memory_planned()) {
ET_LOG(
Error,
"Output %zu is memory planned, or is a constant. Cannot override \
the existing data pointer.",
output_idx);
return Error::InvalidState;
}

auto& t = output.toTensor();
ET_CHECK_OR_RETURN_ERROR(
output.isTensor(),
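A hedged illustration of the set_input() change above: whether the runtime copies or shares the caller's tensor data now depends on that particular input's memory-planning status rather than on a method-wide pre_allocated_input_ flag. The helper below reuses the includes and using-declarations from the earlier sketch; its name and variable names are illustrative.

// Sketch only: with this change, a memory-planned input has its data copied
// into the planned buffer, so the caller's storage can be released after
// set_input(). A non-planned input is shared (no copy), so the caller must
// keep its tensor storage alive until the method finishes executing.
bool caller_must_keep_input_alive(Method& method, size_t input_idx) {
  auto in_meta = method.method_meta().input_tensor_meta(input_idx);
  return !(in_meta.ok() && in_meta->is_memory_planned());
}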
12 changes: 2 additions & 10 deletions runtime/executor/method.h
@@ -62,9 +62,7 @@ class Method final {
delegates_(rhs.delegates_),
n_chains_(rhs.n_chains_),
chains_(rhs.chains_),
init_state_(rhs.init_state_),
pre_allocated_input_(rhs.pre_allocated_input_),
pre_allocated_output_(rhs.pre_allocated_output_) {
init_state_(rhs.init_state_) {
// Required: clear out fields that the dtor looks at, so that we don't free
// anything twice.
rhs.n_value_ = 0;
@@ -82,8 +80,6 @@
rhs.event_tracer_ = nullptr;
rhs.n_chains_ = 0;
rhs.chains_ = nullptr;
rhs.pre_allocated_input_ = false;
rhs.pre_allocated_output_ = false;
}

/**
@@ -288,9 +284,7 @@
delegates_(nullptr),
n_chains_(0),
chains_(nullptr),
init_state_(InitializationState::Uninitialized),
pre_allocated_input_(false),
pre_allocated_output_(false) {}
init_state_(InitializationState::Uninitialized) {}

/// Static factory used by Program.
ET_NODISCARD static Result<Method> load(
@@ -336,8 +330,6 @@
Chain* chains_;

InitializationState init_state_;
bool pre_allocated_input_;
bool pre_allocated_output_;

/**
* Parses the elements of the values_ array. On error, n_value_ will be set to
13 changes: 9 additions & 4 deletions runtime/executor/method_meta.cpp
@@ -139,7 +139,9 @@ Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const {
Span<const uint8_t>(
tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()),
tensor_value->allocation_info() != nullptr);
tensor_value->allocation_info() != nullptr ||
tensor_value->data_buffer_idx() !=
0); // Count constant returns as memory planned.
}

size_t MethodMeta::num_outputs() const {
@@ -170,15 +172,18 @@ Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const {
"Tag: %zu output: %zu is not Tensor",
(size_t)tag.get(),
index);
auto input_index = s_plan_->outputs()->Get(index);
auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
auto output_index = s_plan_->outputs()->Get(index);
auto tensor_value = s_plan_->values()->Get(output_index)->val_as_Tensor();

return TensorInfo(
Span<const int32_t>(
tensor_value->sizes()->data(), tensor_value->sizes()->size()),
Span<const uint8_t>(
tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()),
tensor_value->allocation_info() != nullptr);
tensor_value->allocation_info() != nullptr ||
tensor_value->data_buffer_idx() !=
0); // Count constant returns as memory planned.
}

size_t MethodMeta::num_memory_planned_buffers() const {
Expand Down
