diff --git a/extension/pybindings/test/make_test.py b/extension/pybindings/test/make_test.py index 708e67e430..44e41ed443 100644 --- a/extension/pybindings/test/make_test.py +++ b/extension/pybindings/test/make_test.py @@ -7,10 +7,11 @@ # pyre-unsafe import unittest -from typing import Any, Callable, Tuple +from typing import Any, Callable, Optional, Tuple import torch -from executorch.exir import ExecutorchProgramManager, to_edge +from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge +from executorch.exir.passes import MemoryPlanningPass from torch.export import export @@ -75,8 +76,25 @@ def get_methods_to_export(self): def get_inputs(self): return (torch.ones(2, 2),) + class ModuleAddConstReturn(torch.nn.Module): + """The module to serialize and execute.""" + + def __init__(self): + super(ModuleAddConstReturn, self).__init__() + self.state = torch.ones(2, 2) + + def forward(self, x): + return x + self.state, self.state + + def get_methods_to_export(self): + return ("forward",) + + def get_inputs(self): + return (torch.ones(2, 2),) + def create_program( eager_module: torch.nn.Module, + et_config: Optional[ExecutorchBackendConfig] = None, ) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]: """Returns an executorch program based on ModuleAdd, along with inputs.""" @@ -103,7 +121,7 @@ def forward(self, *args, **kwargs): ) exported_methods[method_name] = export(wrapped_mod, method_input) - exec_prog = to_edge(exported_methods).to_executorch() + exec_prog = to_edge(exported_methods).to_executorch(config=et_config) # Create the ExecuTorch program from the graph. exec_prog.dump_executorch_program(verbose=True) @@ -251,6 +269,34 @@ def test_quantized_ops(tester): expected = example_inputs[0] + example_inputs[1] tester.assertEqual(str(expected), str(executorch_output)) + def test_constant_output_not_memory_planned(tester): + # Create an ExecuTorch program from ModuleAddConstReturn. 
+ exported_program, inputs = create_program( + ModuleAddConstReturn(), + et_config=ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False) + ), + ) + + exported_program.dump_executorch_program(verbose=True) + + # Use pybindings to load and execute the program. + executorch_module = load_fn(exported_program.buffer) + # Invoke the callable on executorch_module instead of calling module.forward. + # Use only one input to test this case. + executorch_output = executorch_module((torch.ones(2, 2),)) + print(executorch_output) + + # The test module adds the input to torch.ones(2,2), so its output should be the same + # as adding them directly. + expected = torch.ones(2, 2) + torch.ones(2, 2) + tester.assertEqual(str(expected), str(executorch_output[0])) + + # The test module returns the state. Check that its value is correct. + tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1])) + + ######### RUN TEST CASES ######### + test_e2e(tester) test_multiple_entry(tester) test_output_lifespan(tester) @@ -258,5 +304,6 @@ def test_quantized_ops(tester): test_module_single_input(tester) test_stderr_redirect(tester) test_quantized_ops(tester) + test_constant_output_not_memory_planned(tester) return wrapper diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index a6ed7e354a..bddcdc3173 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -744,40 +744,6 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) { } } - // Validate input values and get tensor pre-allocation info. - pre_allocated_input_ = false; - for (int i = 0; i < inputs_size(); i++) { - // get_input() will panic if the index is invalid, so do this manually. 
- size_t index = get_input_index(i); - ET_CHECK_OR_RETURN_ERROR( - index < n_value_, - InvalidProgram, - "Input index %zu >= %zu", - index, - n_value_); - const EValue& input = values_[index]; - if (input.isTensor()) { - pre_allocated_input_ |= input.toTensor().const_data_ptr() != nullptr; - } - } - - // Validate output values and get tensor pre-allocation info. - pre_allocated_output_ = false; - for (int i = 0; i < outputs_size(); i++) { - // get_output() will panic if the index is invalid, so do this manually. - size_t index = get_output_index(i); - ET_CHECK_OR_RETURN_ERROR( - index < n_value_, - InvalidProgram, - "output index %zu >= %zu", - index, - n_value_); - const EValue& output = values_[index]; - if (output.isTensor()) { - pre_allocated_output_ |= output.toTensor().const_data_ptr() != nullptr; - } - } - step_state_ = StepState{0, 0}; init_state_ = InitializationState::Initialized; @@ -841,7 +807,8 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) { input_idx, static_cast(err)); Error error; - if (pre_allocated_input_) { + auto tensor_meta = this->method_meta().input_tensor_meta(input_idx); + if (tensor_meta->is_memory_planned()) { error = internal::copy_tensor_data(t_dst, t_src); } else { error = internal::share_tensor_data(t_dst, t_src); @@ -950,21 +917,11 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) { InvalidState, "Outputs can not be retrieved until method has been initialized."); - // ET_CHECK_OR_RETURN_ERROR( - // !pre_allocated_output_, - // InvalidState, - // "Overriding output data pointer allocated by memory plan is not - // allowed."); - // TODO(T188740925): for now, return error without logs. 
- if (pre_allocated_output_) { - return Error::InvalidState; - } - // Check the args ET_CHECK_OR_RETURN_ERROR( - output_idx <= outputs_size(), + output_idx < outputs_size(), InvalidArgument, - "output_idx: %zu num_outputs: %zu", + "output_idx: %zu > num_outputs: %zu", output_idx, outputs_size()); @@ -975,6 +932,16 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) { "output type: %zu is not tensor", (size_t)output.tag); + auto tensor_meta = this->method_meta().output_tensor_meta(output_idx); + if (tensor_meta->is_memory_planned()) { + ET_LOG( + Error, + "Output %zu is memory planned, or is a constant. Cannot override \ + the existing data pointer.", + output_idx); + return Error::InvalidState; + } + auto& t = output.toTensor(); ET_CHECK_OR_RETURN_ERROR( output.isTensor(), diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 0a35d6b928..66e3c96d29 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -62,9 +62,7 @@ class Method final { delegates_(rhs.delegates_), n_chains_(rhs.n_chains_), chains_(rhs.chains_), - init_state_(rhs.init_state_), - pre_allocated_input_(rhs.pre_allocated_input_), - pre_allocated_output_(rhs.pre_allocated_output_) { + init_state_(rhs.init_state_) { // Required: clear out fields that the dtor looks at, so that we don't free // anything twice. rhs.n_value_ = 0; @@ -82,8 +80,6 @@ class Method final { rhs.event_tracer_ = nullptr; rhs.n_chains_ = 0; rhs.chains_ = nullptr; - rhs.pre_allocated_input_ = false; - rhs.pre_allocated_output_ = false; } /** @@ -288,9 +284,7 @@ class Method final { delegates_(nullptr), n_chains_(0), chains_(nullptr), - init_state_(InitializationState::Uninitialized), - pre_allocated_input_(false), - pre_allocated_output_(false) {} + init_state_(InitializationState::Uninitialized) {} /// Static factory used by Program. 
ET_NODISCARD static Result load( @@ -336,8 +330,6 @@ class Method final { Chain* chains_; InitializationState init_state_; - bool pre_allocated_input_; - bool pre_allocated_output_; /** * Parses the elements of the values_ array. On error, n_value_ will be set to diff --git a/runtime/executor/method_meta.cpp b/runtime/executor/method_meta.cpp index 309ecf0ec8..5acf055a89 100644 --- a/runtime/executor/method_meta.cpp +++ b/runtime/executor/method_meta.cpp @@ -139,7 +139,9 @@ Result MethodMeta::input_tensor_meta(size_t index) const { Span( tensor_value->dim_order()->data(), tensor_value->dim_order()->size()), static_cast(tensor_value->scalar_type()), - tensor_value->allocation_info() != nullptr); + tensor_value->allocation_info() != nullptr || + tensor_value->data_buffer_idx() != + 0); // Count constant inputs as memory planned. } size_t MethodMeta::num_outputs() const { @@ -170,15 +172,18 @@ Result MethodMeta::output_tensor_meta(size_t index) const { "Tag: %zu output: %zu is not Tensor", (size_t)tag.get(), index); - auto input_index = s_plan_->outputs()->Get(index); - auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor(); + auto output_index = s_plan_->outputs()->Get(index); + auto tensor_value = s_plan_->values()->Get(output_index)->val_as_Tensor(); + return TensorInfo( Span( tensor_value->sizes()->data(), tensor_value->sizes()->size()), Span( tensor_value->dim_order()->data(), tensor_value->dim_order()->size()), static_cast(tensor_value->scalar_type()), - tensor_value->allocation_info() != nullptr); + tensor_value->allocation_info() != nullptr || + tensor_value->data_buffer_idx() != + 0); // Count constant returns as memory planned. } size_t MethodMeta::num_memory_planned_buffers() const {