diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index e7c16d0031..f1a5b6a50b 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -142,6 +142,41 @@ - arg_meta: null kernel_name: torch::executor::where_out +- op: transpose_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::transpose_copy_int_out + +- op: eq.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::eq_scalar_out + +- op: logical_not.out + kernels: + - arg_meta: null + kernel_name: torch::executor::logical_not_out + +- op: any.out + kernels: + - arg_meta: null + kernel_name: torch::executor::any_out + +- op: native_group_norm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::native_group_norm_out + +- op: sum.IntList_out + kernels: + - arg_meta: null + kernel_name: torch::executor::sum_dim_out + +- op: select_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::select_copy_int_out + # custom ops - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function @@ -183,3 +218,18 @@ kernels: - arg_meta: null kernel_name: impl::reference::quantized_matmul_out + +- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_linear_per_tensor_out + +- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::im2row_out + +- func: cadence::quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) 
+ kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv_per_tensor_out diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt index c40d3ff66b..a2d51af2c0 100644 --- a/backends/cadence/reference/operators/CMakeLists.txt +++ b/backends/cadence/reference/operators/CMakeLists.txt @@ -55,6 +55,16 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_expand_copy.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_transpose_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_eq.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_logical_not.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_any.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_native_group_norm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sum.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_select_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/select_copy_util.cpp" ) add_library(aten_ops_cadence ${_aten_ops__srcs}) target_link_libraries(aten_ops_cadence PUBLIC executorch) @@ -78,6 +88,7 @@ add_library( "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp" "quantized_matmul_out.cpp" + "im2row_out.cpp" ) target_include_directories( custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} diff --git a/backends/cadence/reference/operators/im2row_out.cpp b/backends/cadence/reference/operators/im2row_out.cpp new file mode 100644 index 0000000000..dd539b6f9b --- /dev/null +++ b/backends/cadence/reference/operators/im2row_out.cpp @@ -0,0 +1,206 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include + +#include + +namespace impl { +namespace reference { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +template <typename T> +__attribute__((always_inline)) void im2row_( + const T* __restrict__ data_im, + const int32_t in_zero_point, + /* input parameters */ + const int32_t channels, + const int32_t height, + const int32_t width, + /* output parameters */ + const int32_t out_height, + const int32_t out_width, + /* convolution parameters */ + const int32_t kernel_h, + const int32_t kernel_w, + const int32_t pad_h, + const int32_t pad_w, + const int32_t stride_h, + const int32_t stride_w, + const int32_t dilation_h, + const int32_t dilation_w, + T* __restrict__ data_col, + bool channels_last) { + // Consider convolving the input image of dimensions channels * height * width + // (or height * width * channels for NHWC layout) with a filter of dimensions + // channels * kernel_h * kernel_w. Assume that this convolution will produce + // an output of dimensions out_height x out_width. For each point in the output, + // im2row takes the data from the input that is used in the computation of + // that output point, and flattens it into a vector of size channels_col = + // channels * kernel_h * kernel_w. The output of im2row will therefore be a 2D + // array of size (out_height * out_width) x channels_col. + const int32_t channels_col = channels * kernel_h * kernel_w; + + // If the layout is NHWC, we can copy 'channels' worth of contiguous data + // points when performing im2row.
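Before the channels_last branch that follows, here is a minimal standalone sketch of the shape bookkeeping described in the comment above. It is not part of the patch, and the 3x5x5 input with a 3x3 kernel, stride 1, pad 1, dilation 1 is an illustrative assumption:

#include <cstdint>
#include <cstdio>

int main() {
  // Toy configuration, chosen only for illustration.
  const int32_t channels = 3, height = 5, width = 5;
  const int32_t kernel_h = 3, kernel_w = 3;
  const int32_t pad_h = 1, pad_w = 1, stride_h = 1, stride_w = 1;
  const int32_t dilation_h = 1, dilation_w = 1;

  // Same output-size formula the operator uses further down in this file.
  const int32_t out_h =
      (height + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
  const int32_t out_w =
      (width + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;

  // Each output point owns one row of length channels_col.
  const int32_t channels_col = channels * kernel_h * kernel_w;  // 27
  const int32_t rows = out_h * out_w;                           // 25
  std::printf("im2row output: %d x %d\n", (int)rows, (int)channels_col);

  // In NHWC, the innermost 'channels' values of every (kh, kw) tap are
  // contiguous in memory, which is why the kernel below can memcpy them
  // one tap at a time.
  return 0;
}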
+ if (channels_last) { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + // Each point in the output domain is the result of applying a filter of + // size kernel_h x kernel_w x channels on the input. But since channels + // is contiguous, we will not explicitly have a loop for it. + for (int _kh = 0; _kh < kernel_h; ++_kh) { + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + for (int _kw = 0; _kw < kernel_w; ++_kw) { + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + + // h_im and w_im are the actual height and width coordinates of the + // input tensor from where we need to copy 'channels' points. + const T* __restrict__ slice_im = + data_im + (h_im * width + w_im) * channels; + T* __restrict__ slice_col = data_col + i_col * channels_col + + (_kh * kernel_w + _kw) * channels; + // If the coordinates were within the input domain, we copy + // 'channels' contiguous values. Otherwise we fill the output with + // the input zero point. + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + std::memcpy(slice_col, slice_im, channels * sizeof(T)); + } else { + std::fill_n(slice_col, channels, T(in_zero_point)); + } + } + } + } + } + } else { + // Iterate over the output domain + for (int _h = 0; _h < out_height; ++_h) { + for (int _w = 0; _w < out_width; ++_w) { + int32_t i_col = _h * out_width + _w; + + // Each point in the output domain is the result of applying a filter + // of size channels * kernel_h * kernel_w on the input. + for (int _c = 0; _c < channels; ++_c) { + for (int _kh = 0; _kh < kernel_h; ++_kh) { + for (int _kw = 0; _kw < kernel_w; ++_kw) { + // c_col is the linearized access in the channels_col vector. + int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw; + // h_im and w_im are the actual height and width coordinates of + // the input tensor that we need to copy to the output. + int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; + int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + // If the current data access is within the input tensor, copy the + // value; otherwise write the input zero point. + data_col[i_col * channels_col + c_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? data_im[(_c * height + h_im) * width + w_im] + : static_cast<T>(in_zero_point); + } + } + } + } + } + } +} + +void im2row_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + const Tensor& in_zero_point, + bool channel_last, + Tensor& out) { + // Compute the input tensor's dims + bool unit_height = input.dim() == 3; + const int32_t batch_size = input.size(0); + const int32_t in_c = + channel_last ? input.size(3 - unit_height) : input.size(1); + const int32_t in_h = + unit_height ? 1 : (channel_last ? input.size(1) : input.size(2)); + const int32_t in_w = + channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height); + + // Get the kernel parameters + int32_t kernel_h = kernel_size[0]; + int32_t kernel_w = kernel_size[1]; + int32_t dilation_h = dilation[0]; + int32_t dilation_w = dilation[1]; + int32_t pad_h = padding[0]; + int32_t pad_w = padding[1]; + int32_t stride_h = stride[0]; + int32_t stride_w = stride[1]; + + // If we were to apply a convolution on the input tensor, compute the output + // height and width.
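One detail worth noting in the branches above: out-of-bounds taps are filled with in_zero_point rather than a literal 0, so that a downstream quantized kernel, which subtracts the input zero point from every tap, treats padding as an exact zero. A tiny hedged sketch with made-up numbers follows; the hunk then resumes with the usual output-size computation:

#include <cstdint>
#include <cstdio>

int main() {
  // Toy quantized values (uint8) with an assumed zero point of 128.
  const int32_t in_zero_point = 128;
  const uint8_t real_tap = 200;                       // an in-bounds pixel
  const uint8_t padded_tap = (uint8_t)in_zero_point;  // what im2row writes for padding
  const int32_t weight = 3;  // toy weight, zero point already handled

  // A quantized conv/matmul accumulates (x - in_zero_point) * w, so a padded
  // tap contributes exactly zero, just like fp32 zero padding would.
  int32_t acc = 0;
  acc += ((int32_t)real_tap - in_zero_point) * weight;    // 72 * 3 = 216
  acc += ((int32_t)padded_tap - in_zero_point) * weight;  // 0
  std::printf("acc = %d\n", (int)acc);  // 216
  return 0;
}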
+ int32_t out_h = + (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1; + int32_t out_w = + (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1; + + ET_DCHECK_MSG( + (out_h * out_w) == out.size(1), "dimension mismatch for output"); + ET_DCHECK_MSG( + (kernel_h * kernel_w * in_c) == out.size(2), + "dimension mismatch for output"); + + // Check if the input is per-tensor quantized or per-channel quantized. The + // zero point for each batch could differ for per-channel quantized input. + bool per_tensor_quantized = in_zero_point.numel() == 1; + +#define typed_im2row(dtype, ctype) \ + case ScalarType::dtype: { \ + const ctype* __restrict__ in_data = input.const_data_ptr(); \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const int32_t* __restrict__ zero_point = \ + in_zero_point.const_data_ptr(); \ + int32_t in_plane = in_c * in_h * in_w; \ + int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \ + for (size_t n = 0; n < batch_size; ++n) { \ + im2row_( \ + &in_data[n * in_plane], \ + per_tensor_quantized ? zero_point[0] : zero_point[n], \ + in_c, \ + in_h, \ + in_w, \ + out_h, \ + out_w, \ + kernel_h, \ + kernel_w, \ + pad_h, \ + pad_w, \ + stride_h, \ + stride_w, \ + dilation_h, \ + dilation_w, \ + &out_data[n * out_plane], \ + channel_last); \ + } \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + typed_im2row(Float, float); + typed_im2row(Byte, uint8_t); + typed_im2row(Char, int8_t); + default: + ET_DCHECK_MSG( + false, + "im2row not implemented for dtype %s", + torch::executor::toString(dtype)); + } +#undef typed_im2row +} + +} // namespace native +} // namespace reference +} // namespace impl diff --git a/backends/cadence/reference/operators/operators.h b/backends/cadence/reference/operators/operators.h new file mode 100644 index 0000000000..0ff4639255 --- /dev/null +++ b/backends/cadence/reference/operators/operators.h @@ -0,0 +1,57 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include + +namespace cadence { +namespace impl { +namespace cpu { +namespace native { +namespace { +using ::executorch::runtime::getLeadingDims; + +#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) + +inline __attribute__((always_inline)) void linear_( + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::optional<::executorch::aten::Tensor>& bias, + ::executorch::aten::Tensor& output) { + const float* __restrict__ input_data = input.const_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.value().const_data_ptr(); + float* __restrict__ output_data = output.mutable_data_ptr(); + + // input comes in shape [batch_size, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [batch_size, out_dim] + // Perform matrix multiply (M x N) x (N x P) => M x P + int64_t M = weight.size(0); // = out_dim + int64_t N = weight.size(1); // = in_dim + + // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the + // leading dimensions is d0 * d1 * ... 
* d_{N-2} + int64_t leading_dims = getLeadingDims(input, input.dim() - 1); + + for (int i = 0; i < leading_dims; ++i) { + for (int j = 0; j < M; ++j) { + float sum = bias_data[j]; + for (int k = 0; k < N; ++k) { + sum += input_data[i * N + k] * weight_data[j * N + k]; + } + output_data[i * M + j] = sum; + } + } +} + +} // namespace +} // namespace native +} // namespace cpu +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/reference/operators/quantized_conv_out.cpp b/backends/cadence/reference/operators/quantized_conv_out.cpp index de19f3ef43..5a7af85809 100644 --- a/backends/cadence/reference/operators/quantized_conv_out.cpp +++ b/backends/cadence/reference/operators/quantized_conv_out.cpp @@ -1,21 +1,16 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. #include - -#include +#include namespace impl { namespace reference { namespace native { -using executorch::aten::Tensor; -using executorch::runtime::KernelRuntimeContext; +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; // This implements a generic 2d conv kernel that operates on raw pointers. // The version handles both quantized and fp32 convolutions. @@ -23,7 +18,12 @@ using executorch::runtime::KernelRuntimeContext; // The weight is of shape [oc x wc x wh x ww], where wc == c // The output is of shape [n x oc x oh x ow] // The bias is of shape [oc] -template +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> __attribute__((noinline)) void conv2d_nchw_core_generic( // All the arrays const IT* __restrict__ p_in, @@ -56,11 +56,10 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( // input zero point IT in_zero_point = 0, // weight zero point - const int32_t* __restrict__ weight_zero_point = nullptr, - const float* __restrict__ bias_scale = nullptr, + int32_t weight_zero_point = 0, + float bias_scale = 1, float out_scale = 1, - OT out_zero_point = 0, - bool per_tensor_quantized = true) { + OT out_zero_point = 0) { float inv_out_scale = 1. / out_scale; bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; @@ -106,7 +105,7 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( int woff = _wh * ww + _ww; float lhs = in_plane[ioff] - in_zero_point; float rhs = weight_plane[woff] - - (quantized ? weight_zero_point[0] : 0); + (quantized ? weight_zero_point : 0); acc += lhs * rhs; } } @@ -126,7 +125,7 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( int woff = _wh * ww + _ww; float lhs = in_plane[ioff] - in_zero_point; float rhs = weight_plane[woff] - - (quantized ? weight_zero_point[0] : 0); + (quantized ? weight_zero_point : 0); acc += lhs * rhs; } } @@ -134,11 +133,10 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( } } if (quantized) { - float val = - (per_tensor_quantized ? 
bias_scale[0] : bias_scale[_oc]) * - acc; + float val = bias_scale * acc; out_plane[_oh * ow + _ow] = - kernels::quantize(val, inv_out_scale, out_zero_point); + ::impl::reference::kernels::quantize( + val, inv_out_scale, out_zero_point); } else { out_plane[_oh * ow + _ow] = acc; } @@ -149,27 +147,149 @@ __attribute__((noinline)) void conv2d_nchw_core_generic( } } +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nhwc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * h * w * c; + OT* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + OT* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler.x`` + if (zero_pad_unit_dilation) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = + in_batch + (_h + _wh) * w * c + (_w + _ww) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1 < w))) { + const IT* in_line = in_batch + + (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = ::impl::reference::kernels::quantize( + val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } + } +} + // The quantized convolution kernel. in_scale and weight_scale are implicit in // bias_scale, since it is a product of the two. The kernel will branch to // quantized::conv1d or quantized::conv2d based on the dimensionality of // activation tensor. -void quantized_conv_out( - KernelRuntimeContext& ctx, +void quantized_conv_nchw( const Tensor& input, const Tensor& weight, const Tensor& bias, - executorch::aten::IntArrayRef stride, - executorch::aten::IntArrayRef padding, - executorch::aten::IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - const Tensor& out_multiplier, - const Tensor& out_shift, - bool channel_last, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, Tensor& out) { bool conv1d = input.dim() == 3; // input = [n, c, h, w] @@ -186,76 +306,224 @@ void quantized_conv_out( const int oh = conv1d ? 1 : out.size(2); const int ow = conv1d ? out.size(2) : out.size(3); - // Bool flag to check if weight tensor is quantized per-tensor or - // per-channel - bool per_tensor_quantized = bias_scale.numel() == 1; +#define typed_quantized_conv2d_nchw(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nchw_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + h, \ + w, \ + oc, \ + wc, \ + wh, \ + ww, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nchw +} + +void quantized_conv_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, h, w, c] + const int n = input.size(0); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? 
input.size(2) : input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oh, ow, oc] + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? out.size(1) : out.size(2); + +#define typed_quantized_conv2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nhwc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + h, \ + w, \ + c, \ + oc, \ + wh, \ + ww, \ + wc, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nhwc +} - if (out.scalar_type() == exec_aten::ScalarType::Byte) { - conv2d_nchw_core_generic( - input.const_data_ptr(), - weight.const_data_ptr(), - bias.const_data_ptr(), - out.mutable_data_ptr(), - n, - c, - h, - w, - oc, - wc, - wh, - ww, - oh, - ow, - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], +void quantized_conv_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED const Tensor& out_multiplier, + __ET_UNUSED const Tensor& out_shift, + bool channel_last, + Tensor& out) { + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + if (channel_last) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, groups, in_zero_point, - weight_zero_point.const_data_ptr(), - bias_scale.const_data_ptr(), + weight_zero_point_int, + bias_scale_float, output_scale, - (uint8_t)output_zero_point, - per_tensor_quantized); - } else if (out.scalar_type() == exec_aten::ScalarType::Char) { - conv2d_nchw_core_generic( - input.const_data_ptr(), - weight.const_data_ptr(), - bias.const_data_ptr(), - out.mutable_data_ptr(), - n, - c, - h, - w, - oc, - wc, - wh, - ww, - oh, - ow, - stride[0], - stride[1], - padding[0], - padding[1], - dilation[0], - dilation[1], + output_zero_point, + out); + } else { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, groups, in_zero_point, - weight_zero_point.const_data_ptr(), - bias_scale.const_data_ptr(), + weight_zero_point_int, + bias_scale_float, output_scale, - (int8_t)output_zero_point, - per_tensor_quantized); + output_zero_point, + out); + } +} + +void quantized_conv_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, 
+ __ET_UNUSED int64_t out_shift, + bool channel_last, + Tensor& out) { + if (channel_last) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(input.scalar_type())); + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); } } -}; // namespace native -}; // namespace reference -}; // namespace impl +} // namespace native +} // namespace reference +} // namespace impl diff --git a/backends/cadence/reference/operators/quantized_linear_out.cpp b/backends/cadence/reference/operators/quantized_linear_out.cpp index 7bb1bf6fb4..4f7ca9cc3c 100644 --- a/backends/cadence/reference/operators/quantized_linear_out.cpp +++ b/backends/cadence/reference/operators/quantized_linear_out.cpp @@ -6,7 +6,8 @@ * LICENSE file in the root directory of this source tree. */ -#include +#include +#include #include namespace impl { @@ -85,6 +86,7 @@ void quantized_linear_out( int64_t out_zero_point, __ET_UNUSED const executorch::aten::optional& offset, Tensor& out) { + // TODO: refactor to use switch case as quantized_linear_per_tensor_out if (out.scalar_type() == executorch::aten::ScalarType::Byte) { _typed_quantized_linear( src, @@ -115,6 +117,43 @@ void quantized_linear_out( } } +void quantized_linear_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + __ET_UNUSED const executorch::aten::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case executorch::aten::ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + executorch::aten::ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor +} + }; // namespace native }; // namespace reference }; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_ops.h b/backends/cadence/reference/operators/quantized_ops.h new file mode 100644 index 0000000000..66545c8e58 --- /dev/null +++ b/backends/cadence/reference/operators/quantized_ops.h @@ -0,0 +1,190 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
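The quantized_linear_per_tensor_out hunk above uses the same dtype-dispatch idiom that recurs throughout this patch: an X-macro (ET_FORALL_CADENCE_QUANTIZED_TYPES, defined in operators.h) expands one switch case per supported quantized ctype. Below is a self-contained toy version of that idiom; the FORALL_TOY_QUANTIZED_TYPES macro and the ScalarType enum here are stand-ins, not the real ExecuTorch definitions. The new quantized_ops.h header, whose preamble opens just above, then supplies the shared templates such switches instantiate:

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for ET_FORALL_CADENCE_QUANTIZED_TYPES: it expands its
// argument once per supported (ctype, dtype) pair.
#define FORALL_TOY_QUANTIZED_TYPES(_) \
  _(uint8_t, Byte)                    \
  _(int8_t, Char)

enum class ScalarType { Byte, Char, Float };

template <typename T>
void kernel_impl() {
  std::printf("running kernel for %zu-byte element\n", sizeof(T));
}

void dispatch(ScalarType dtype) {
  // Each X-macro expansion becomes one case of the switch, mirroring the
  // typed_quantized_linear_per_tensor block in the hunk above.
#define TOY_CASE(ctype, name) \
  case ScalarType::name:      \
    kernel_impl<ctype>();     \
    break;

  switch (dtype) {
    FORALL_TOY_QUANTIZED_TYPES(TOY_CASE)
    default:
      std::printf("unhandled dtype\n");
  }
#undef TOY_CASE
}

int main() {
  dispatch(ScalarType::Char);  // prints: running kernel for 1-byte element
  return 0;
}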
+ +#pragma once + +#include +#include + +template +inline __attribute__((always_inline)) void quantized_linear_per_tensor_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + const int64_t leading_dims = + executorch::runtime::getLeadingDims(src, src.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const T* __restrict__ in_data = src.const_data_ptr(); + const T* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + + // Compute the requant_scale from out_multiplier and out_shift + const float requant_scale = + -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); + + for (size_t i = 0; i < leading_dims; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + int32_t sum = bias_data[j]; + for (size_t k = 0; k < in_dim; ++k) { + int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; + int32_t w = + (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; + sum += x * w; + } + out_data[i * out_dim + j] = ::impl::reference::kernels::quantize( + sum, requant_scale, out_zero_point); + } + } +} + +template +inline __attribute__((always_inline)) void quantized_linear_per_tensor_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // Get the zero_point of weight. 
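Before the overload above resumes below (it simply reads the scalar zero point out of weight_zero_point_t), here is a hedged numeric sketch of the requantization scale computed by quantized_linear_per_tensor_: requant_scale = -out_multiplier / 2^31 * 2^out_shift, which kernels::quantize then applies as, roughly, round(sum * scale) + out_zero_point with saturation. The multiplier and shift values below are made up:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example values; real ones come from the quantizer.
  const int64_t out_multiplier = -1073741824;  // about -0.5 * 2^31
  const int64_t out_shift = -1;
  const int64_t out_zero_point = 128;
  const int32_t sum = 400;  // an int32 accumulator from the matmul loop

  // Equivalent to -out_multiplier * 1.0 / 2^31 * pow(2, out_shift) above;
  // ldexp is used here to sidestep the signed-shift pitfall of (1 << 31).
  const float requant_scale =
      (float)std::ldexp(-out_multiplier / 2147483648.0, (int)out_shift);

  // Roughly what kernels::quantize does (the real helper also saturates to
  // the output dtype's range).
  const float scaled = sum * requant_scale;  // 400 * 0.25 = 100
  const int32_t q = (int32_t)std::nearbyint(scaled) + (int32_t)out_zero_point;
  std::printf("requant_scale = %f, q = %d\n", requant_scale, (int)q);  // 0.25, 228
  return 0;
}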
+ int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; + quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} + +template +inline __attribute__((always_inline)) void quantized_linear_per_channel_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + int64_t weight_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + int64_t leading_dims = + executorch::runtime::getLeadingDims(src, src.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const T* __restrict__ in_data = src.const_data_ptr(); + const T* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + for (size_t i = 0; i < leading_dims; ++i) { + for (size_t j = 0; j < out_dim; ++j) { + int32_t sum = bias_data[j]; + for (size_t k = 0; k < in_dim; ++k) { + int32_t x = (int32_t)in_data[i * in_dim + k] - src_zero_point; + int32_t w = + (int32_t)weight_data[j * in_dim + k] - (int32_t)weight_zero_point; + sum += x * w; + } + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[j] * 1.0 / (1 << 31) * pow(2, out_shift_data[j]); + out_data[i * out_dim + j] = ::impl::reference::kernels::quantize( + sum, out_scale, out_zero_point); + } + } +} + +template +inline __attribute__((always_inline)) void quantized_linear_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + int64_t weight_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + if (out_multiplier.numel() == 1) { + // Use per-tensor quantization kernel. + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier_data[0], + out_shift_data[0], + out_zero_point, + out); + return; + } + + // Use per-channel quantization kernel. 
+ quantized_linear_per_channel_<T>( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} + +template <typename T> +inline __attribute__((always_inline)) void quantized_linear_( + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + ::executorch::aten::Tensor& out) { + // Get the zero_point of weight. + int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0]; + quantized_linear_<T>( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); +} diff --git a/backends/cadence/reference/operators/targets.bzl b/backends/cadence/reference/operators/targets.bzl index 347d476239..488aeebb82 100644 --- a/backends/cadence/reference/operators/targets.bzl +++ b/backends/cadence/reference/operators/targets.bzl @@ -7,6 +7,9 @@ def define_common_targets(): srcs = glob([ "*.cpp", ]), + exported_headers = glob([ + "*.h", + ]), platforms = CXX, deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util",
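Finally, a note on the channel_last routing added in quantized_conv_out: the NCHW and NHWC kernels differ only in how an (n, c, h, w) coordinate maps to a flat offset, which is what the sketch below illustrates. The helper functions and toy sizes are illustrative, not part of the patch:

#include <cstdint>
#include <cstdio>

// Flat-offset helpers for the two layouts the conv kernels above iterate over.
inline int64_t nchw_offset(int64_t n, int64_t c, int64_t h, int64_t w,
                           int64_t C, int64_t H, int64_t W) {
  return ((n * C + c) * H + h) * W + w;
}

inline int64_t nhwc_offset(int64_t n, int64_t c, int64_t h, int64_t w,
                           int64_t C, int64_t H, int64_t W) {
  return ((n * H + h) * W + w) * C + c;
}

int main() {
  // Toy tensor: N=1, C=4, H=3, W=3 (sizes chosen only for illustration).
  const int64_t C = 4, H = 3, W = 3;
  // Walking along c with fixed (n, h, w): stride H*W in NCHW, stride 1 in NHWC,
  // which is why the NHWC kernel can keep the channel loop innermost.
  std::printf("NCHW: %lld -> %lld\n",
              (long long)nchw_offset(0, 0, 1, 1, C, H, W),
              (long long)nchw_offset(0, 1, 1, 1, C, H, W));  // 4 -> 13
  std::printf("NHWC: %lld -> %lld\n",
              (long long)nhwc_offset(0, 0, 1, 1, C, H, W),
              (long long)nhwc_offset(0, 1, 1, 1, C, H, W));  // 16 -> 17
  return 0;
}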