From d9ab717de87e296b8e20387a282318c267698752 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 26 Jul 2024 14:16:30 -0700 Subject: [PATCH] Add exportable baby llama example (#4345) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4345 Add a small LLaMa model, based on the babyllama paper. Note that this test case is only one layer by default, and the number of layers can be adjusted in the test. Removed some pyre changes that broke the OSS AoT export, and added some required passes and operators. Differential Revision: D60073137 --- backends/cadence/aot/compiler.py | 20 ++-- backends/cadence/aot/functions.yaml | 20 ++++ backends/cadence/aot/passes.py | 103 +++++++++++++++++- backends/cadence/aot/quantizer/quantizer.py | 24 ++-- .../reference/operators/CMakeLists.txt | 6 +- .../operators/quantized_matmul_out.cpp | 42 +++---- examples/cadence/models/babyllama.py | 39 +++++++ 7 files changed, 206 insertions(+), 48 deletions(-) create mode 100644 examples/cadence/models/babyllama.py diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 302252c42a..39511ae917 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -11,23 +11,23 @@ import torch from executorch.backends.cadence.aot.passes import ( + InitializePipeline, + RemoveNopExpandOpPass, RemoveZeroSizedCatArgsPass, + ReplaceLogicalNotBooleanWhereWithWherePass, ReplacePT2DequantWithCadenceDequantPass, ReplacePT2QuantWithCadenceQuantPass, ReplaceScalarTensorWithFullPass, ReplaceSqueezeAndUnsqueezeWithViewPass, ) from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion -from executorch.backends.cadence.aot.quantizer.quantizer import ( - CadenceAtenQuantizer, - CadenceQuantizer, -) +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer from executorch.backends.cadence.aot.utils import model_is_quantized from executorch.backends.transforms.decompose_sdpa import ( DecomposeScaledDotProductAttention, ) +from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from pyre_extensions import assert_is_instance from torch._export import capture_pre_autograd_graph from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -63,10 +63,8 @@ def quantize_pt2( converted_model = convert_pt2e(prepared_model) # Get patterns and apply fusion of dq -> op -> q to qop - patterns = [ - assert_is_instance(q, CadenceAtenQuantizer).pattern - for q in quantizer.quantizers - ] + # pyre-ignore[16]: no attribute + patterns = [q.pattern for q in quantizer.quantizers] QuantFusion(patterns)(converted_model) return converted_model @@ -148,8 +146,12 @@ def export_to_cadence( # Run a couple required passes for quant/dequant ops cadence_program_manager = edge_program_manager.transform( [ + InitializePipeline(), RemoveZeroSizedCatArgsPass(), + ReplaceLogicalNotBooleanWhereWithWherePass(), ReplaceScalarTensorWithFullPass(), + RemoveCloneOpsTransform(), + RemoveNopExpandOpPass(), ReplaceSqueezeAndUnsqueezeWithViewPass(), ReplacePT2QuantWithCadenceQuantPass(), ReplacePT2DequantWithCadenceDequantPass(), diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index f79d5f870d..dbfe1e3639 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -62,16 +62,31 @@ - arg_meta: null kernel_name: torch::executor::full_out +- op: mean.out + kernels: + - arg_meta: null + kernel_name: torch::executor::mean_dim_out + - op: mul.out kernels: - arg_meta: null kernel_name: torch::executor::mul_out +- op: mul.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::mul_scalar_out + - op: permute_copy.out kernels: - arg_meta: null kernel_name: torch::executor::permute_copy_out +- op: rsqrt.out + kernels: + - arg_meta: null + kernel_name: torch::executor::rsqrt_out + - op: sigmoid.out kernels: - arg_meta: null @@ -134,3 +149,8 @@ kernels: - arg_meta: null kernel_name: impl::reference::quantized_relu_out + +func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_matmul_out diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index ca8a44f00c..db419bfb5e 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -4,18 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Tuple +# pyre-strict + +from typing import Any, cast, Dict, Sequence, Tuple import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from executorch.exir.passes import dead_code_elimination_pass +from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch._subclasses import FakeTensor from torch.utils._pytree import tree_map_only - -# pyre-strict - # Similar to what's done in executorch/exir/pass_base.py Argument = Any # pyre-ignore @@ -173,3 +174,95 @@ def call_operator( init_args[0] = new_args args = tuple(args) return super().call_operator(op, args, kwargs, meta) + + +class RemoveNopExpandOpPass(ExportPass): + """ + For an expand op, if the operator shape matches the expand shape, then the + expand is a nop. + """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if get_edge_overload_packet(op) not in { + exir_ops.edge.aten.expand_copy, + exir_ops.edge.aten.expand, + }: + return super().call_operator(op, args, kwargs, meta) + + # Parse the args, and check for nop condition + arg0 = cast(ProxyValue, args[0]) + arg1 = cast(Sequence[int], args[1]) + in_tensor = arg0.to_tensor() + if list(in_tensor.shape) == list(arg1): + return arg0 + + return super().call_operator(op, args, kwargs, meta) + + +class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): + """ + A where op with a logical_not and a boolean tensor can be replaced + by a where op with flipped inputs and the initial boolean tensor. + """ + + def replace_logical_nop_where_with_where( + self, graph_module: torch.fx.GraphModule + ) -> None: + graph = graph_module.graph + for node in graph.nodes: + # We are only interested in where nodes + if node.target != exir_ops.edge.aten.where.self: + continue + + # If the third arg is not a logical_not, bail. + if node.args[0].target != exir_ops.edge.aten.logical_not.default: + continue + + # Get the third arg node and its input + logical_not_node = node.args[0] + logical_not_input_tensor = ( + logical_not_node.args[0].to_tensor() + if isinstance(logical_not_node.args[0], ProxyValue) + else logical_not_node.args[0] + ) + + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_tensor.meta["spec"].dtype != torch.bool: + continue + + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. + with graph.inserting_before(node): + linear_node = graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_node.args[0], node.args[2], node.args[1]), + ) + # Replace all the uses + node.replace_all_uses_with(linear_node) + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.replace_logical_nop_where_with_where(graph_module) + result = super().call(graph_module) + return result + + +class InitializePipeline(ExportPass): + """ + Initialize the Jarvis pipeline. This should invariably be the first pass to + run. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + dead_code_elimination_pass(graph_module) + result = SpecPropPass()(graph_module) + assert result is not None + return result diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 4cd3c6bfb4..61e414ca10 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -26,7 +26,6 @@ is_annotated, no_outside_users, ) -from pyre_extensions import assert_is_instance from torch import fx @@ -100,14 +99,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: continue for output, *custom_spec in anchors.output: - assert_is_instance(output, fx.Node).meta["quantization_annotation"] = ( - QuantizationAnnotation( - # pyre-ignore[6]: incompatible parameter type - output_qspec=( - custom_spec[0] if custom_spec else output_act_qspec - ), - _annotated=True, - ) + # pyre-ignore[16]: no attribute + output.meta["quantization_annotation"] = QuantizationAnnotation( + # pyre-ignore[6]: incompatible parameter type + output_qspec=(custom_spec[0] if custom_spec else output_act_qspec), + _annotated=True, ) def annotate_inputs( @@ -118,16 +114,18 @@ def annotate_inputs( spec: Optional[QuantizationSpec], ) -> None: for node, idx, *custom_spec in inputs: - _node = assert_is_instance(node, fx.Node) - annotation = _node.meta.get( + # pyre-ignore[16]: no attribute + annotation = node.meta.get( "quantization_annotation", QuantizationAnnotation(_annotated=True), ) # pyre-ignore[6]: incompatible parameter type - annotation.input_qspec_map[_node.args[idx]] = ( + # pyre-ignore[16]: no attribute + annotation.input_qspec_map[node.args[idx]] = ( custom_spec[0] if custom_spec else spec ) - _node.meta["quantization_annotation"] = annotation + # pyre-ignore[16]: no attribute + node.meta["quantization_annotation"] = annotation annotate_inputs(anchors.inputs, input_act_qspec) annotate_inputs(anchors.weights, weight_qspec) diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt index c22dc0c997..c81e934850 100644 --- a/backends/cadence/reference/operators/CMakeLists.txt +++ b/backends/cadence/reference/operators/CMakeLists.txt @@ -32,12 +32,15 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp" @@ -60,7 +63,8 @@ target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/.. add_library( custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp" "quantized_relu_out.cpp" "quantized_layer_norm.cpp" - "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp") + "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp" + "quantized_matmul_out.cpp") target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} ${_common_include_directories}) diff --git a/backends/cadence/reference/operators/quantized_matmul_out.cpp b/backends/cadence/reference/operators/quantized_matmul_out.cpp index 95df35caba..49dd222a96 100644 --- a/backends/cadence/reference/operators/quantized_matmul_out.cpp +++ b/backends/cadence/reference/operators/quantized_matmul_out.cpp @@ -13,6 +13,9 @@ namespace impl { namespace reference { namespace native { +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + // The quantized matmul. The quantized matmul accumulates in a wider register, // whose type is TA. template < @@ -50,27 +53,32 @@ __attribute__((noinline)) void qmatmul( } } -template +template void inline _typed_quantized_matmul( const Tensor& X, int64_t X_zero_point, const Tensor& Y, int64_t Y_zero_point, - const c10::optional& bias, + const exec_aten::optional& bias, int64_t out_multiplier, int64_t out_shift, int64_t out_zero_point, bool transposed, Tensor& out) { - ctype* __restrict__ out_data = out.mutable_data_ptr(); - const ctype* __restrict__ X_data = X.const_data_ptr(); - const ctype* __restrict__ Y_data = Y.const_data_ptr(); + size_t batch_size = getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + T* __restrict__ out_data = out.mutable_data_ptr(); + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); for (size_t i = 0; i < batch_size; ++i) { - const ctype* x = X_data + i * leading_dim * in_dim; - const ctype* y = Y_data + i * in_dim * out_dim; - ctype* z = out_data + i * leading_dim * out_dim; + const T* x = X_data + i * leading_dim * in_dim; + const T* y = Y_data + i * in_dim * out_dim; + T* z = out_data + i * leading_dim * out_dim; if (transposed) { - qmatmul( + qmatmul( z, static_cast(out_multiplier), static_cast(out_shift), @@ -83,7 +91,7 @@ void inline _typed_quantized_matmul( in_dim, out_dim); } else { - qmatmul( + qmatmul( z, static_cast(out_multiplier), static_cast(out_shift), @@ -101,24 +109,18 @@ void inline _typed_quantized_matmul( } void quantized_matmul_out( + RuntimeContext& ctx, const Tensor& X, int64_t X_zero_point, const Tensor& Y, int64_t Y_zero_point, - const c10::optional& bias, + const exec_aten::optional& bias, int64_t out_multiplier, int64_t out_shift, int64_t out_zero_point, bool transposed, Tensor& out) { - (void)bias; - - size_t batch_size = getLeadingDims(X, X.dim() - 2); - size_t leading_dim = X.size(X.dim() - 2); - size_t out_dim = Y.size(Y.dim() - 1 - transposed); - size_t in_dim = X.size(X.dim() - 1); - - if (out.ScalarType() == at::ScalarType::Byte) { + if (out.scalar_type() == at::ScalarType::Byte) { _typed_quantized_matmul( X, X_zero_point, @@ -130,7 +132,7 @@ void quantized_matmul_out( out_zero_point, transposed, out); - } else if (out.ScalarType() == at::ScalarType::Char) { + } else if (out.scalar_type() == at::ScalarType::Char) { _typed_quantized_matmul( X, X_zero_point, diff --git a/examples/cadence/models/babyllama.py b/examples/cadence/models/babyllama.py new file mode 100644 index 0000000000..347f9b4a7a --- /dev/null +++ b/examples/cadence/models/babyllama.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch + +from executorch.backends.cadence.aot.export_example import export_model + +from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +if __name__ == "__main__": + + args = ModelArgs( + dim=512, + vocab_size=512, + hidden_dim=1024, + n_heads=8, + # use_kv_cache=True, + n_layers=1, + ) + seq = 64 + b = 1 + model = Transformer(args) + example_inputs = (torch.randint(0, 10, [b, seq], dtype=torch.int64),) + + export_model(model, example_inputs)