cadence quantized_linear_per_tensor_out cpu
Differential Revision: D66915708

Pull Request resolved: #7236
kirklandsign authored Dec 10, 2024
1 parent 5161d70 commit a6841d5
Showing 8 changed files with 923 additions and 99 deletions.
50 changes: 50 additions & 0 deletions backends/cadence/aot/functions.yaml
@@ -142,6 +142,41 @@
- arg_meta: null
kernel_name: torch::executor::where_out

- op: transpose_copy.int_out
kernels:
- arg_meta: null
kernel_name: torch::executor::transpose_copy_int_out

- op: eq.Scalar_out
kernels:
- arg_meta: null
kernel_name: torch::executor::eq_scalar_out

- op: logical_not.out
kernels:
- arg_meta: null
kernel_name: torch::executor::logical_not_out

- op: any.out
kernels:
- arg_meta: null
kernel_name: torch::executor::any_out

- op: native_group_norm.out
kernels:
- arg_meta: null
kernel_name: torch::executor::native_group_norm_out

- op: sum.IntList_out
kernels:
- arg_meta: null
kernel_name: torch::executor::sum_dim_out

- op: select_copy.int_out
kernels:
- arg_meta: null
kernel_name: torch::executor::select_copy_int_out

# custom ops
- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
variants: function
@@ -183,3 +218,18 @@
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_matmul_out

- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_linear_per_tensor_out

- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::im2row_out

- func: cadence::quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_conv_per_tensor_out
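
The cadence::quantized_linear.per_tensor_out schema registered above carries per-tensor zero points plus an out_multiplier/out_shift/out_zero_point triple for requantization. Below is a minimal standalone sketch of what such a per-tensor quantized linear kernel could look like; it is illustrative only. The function name, uint8 layout, and the multiplier/shift requantization convention are assumptions, not taken from this diff (the registered kernel is impl::reference::quantized_linear_per_tensor_out, whose source is not shown here).

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical sketch: out[m][n] = requantize(sum_k (src - src_zp) * (w - w_zp) + bias[n]).
// src is [M, K], weight is [N, K], bias is [N]; all names and the rounding
// convention below are assumptions for illustration.
std::vector<uint8_t> quantized_linear_per_tensor_sketch(
    const std::vector<uint8_t>& src,
    const std::vector<uint8_t>& weight,
    const std::vector<int32_t>& bias,
    int64_t M,
    int64_t K,
    int64_t N,
    int32_t src_zero_point,
    int32_t weight_zero_point,
    int32_t out_multiplier,
    int32_t out_shift,
    int32_t out_zero_point) {
  std::vector<uint8_t> out(M * N);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      int32_t acc = bias[n];
      for (int64_t k = 0; k < K; ++k) {
        acc += (static_cast<int32_t>(src[m * K + k]) - src_zero_point) *
            (static_cast<int32_t>(weight[n * K + k]) - weight_zero_point);
      }
      // Requantize with one common convention: treat out_multiplier/out_shift
      // as a fixed-point scale, then add the output zero point and clamp to
      // the uint8 range. The exact convention the Cadence kernel uses is not
      // shown in this diff.
      int64_t scaled = (static_cast<int64_t>(acc) * out_multiplier) >> out_shift;
      int64_t v = scaled + out_zero_point;
      out[m * N + n] =
          static_cast<uint8_t>(std::min<int64_t>(255, std::max<int64_t>(0, v)));
    }
  }
  return out;
}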
11 changes: 11 additions & 0 deletions backends/cadence/reference/operators/CMakeLists.txt
@@ -55,6 +55,16 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_expand_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_transpose_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_eq.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_logical_not.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_any.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_native_group_norm.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sum.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_select_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/util/select_copy_util.cpp"
)
add_library(aten_ops_cadence ${_aten_ops__srcs})
target_link_libraries(aten_ops_cadence PUBLIC executorch)
@@ -78,6 +88,7 @@ add_library(
"quantize_per_tensor.cpp"
"dequantize_per_tensor.cpp"
"quantized_matmul_out.cpp"
"im2row_out.cpp"
)
target_include_directories(
custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR}
206 changes: 206 additions & 0 deletions backends/cadence/reference/operators/im2row_out.cpp
@@ -0,0 +1,206 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <executorch/backends/cadence/reference/operators/operators.h>

#include <algorithm>
#include <cstring>

namespace impl {
namespace reference {
namespace native {

using ::executorch::aten::IntArrayRef;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

template <typename T>
__attribute__((always_inline)) void im2row_(
const T* __restrict__ data_im,
const int32_t in_zero_point,
/* input parameters */
const int32_t channels,
const int32_t height,
const int32_t width,
/* output parameters */
const int32_t out_height,
const int32_t out_width,
/* convolution parameters */
const int32_t kernel_h,
const int32_t kernel_w,
const int32_t pad_h,
const int32_t pad_w,
const int32_t stride_h,
const int32_t stride_w,
const int32_t dilation_h,
const int32_t dilation_w,
T* __restrict__ data_col,
bool channels_last) {
// Consider convolving the input image of dimensions channels * height * width
// (or height * width * channels for NHWC layout) with a filter of dimensions
// channels * kernel_h * kernel_w. Assume that this convolution will produce
// an output of dimensions out_height x out_width. For each point in the
// output, im2row takes the data from the input that is used in the
// computation of that output point, and flattens it into a vector of size
// channels_col = channels * kernel_h * kernel_w. The output of im2row will
// therefore be a 2D array of size (out_height * out_width) x channels_col.
const int32_t channels_col = channels * kernel_h * kernel_w;

// If the layout is NHWC, we can copy 'channels' worth of contiguous data
// points when performing im2row.
if (channels_last) {
// Iterate over the output domain
for (int _h = 0; _h < out_height; ++_h) {
for (int _w = 0; _w < out_width; ++_w) {
int32_t i_col = _h * out_width + _w;
// Each point in the output domain is the result of applying a filter of
// size kernel_h x kernel_w x channels on the input. But since channels
// is contiguous, we will not explicitly have a loop for it.
for (int _kh = 0; _kh < kernel_h; ++_kh) {
int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h;
for (int _kw = 0; _kw < kernel_w; ++_kw) {
int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w;

// h_im and w_im are the actual height and width coordinates of the
// input tensor from where we need to copy 'channels' points.
const T* __restrict__ slice_im =
data_im + (h_im * width + w_im) * channels;
T* __restrict__ slice_col = data_col + i_col * channels_col +
(_kh * kernel_w + _kw) * channels;
// If the coordinates were within the input domain, we copy
// 'channels' contiguous values. Otherwise we will fill the output
// with 0's.
if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
std::memcpy(slice_col, slice_im, channels * sizeof(T));
} else {
std::fill_n(slice_col, channels, T(in_zero_point));
}
}
}
}
}
} else {
// Iterate over the output domain
for (int _h = 0; _h < out_height; ++_h) {
for (int _w = 0; _w < out_width; ++_w) {
int32_t i_col = _h * out_width + _w;

// Each point in the output domain is the result of applying a filter
// of size channels * kernel_h * kernel_w on the input
for (int _c = 0; _c < channels; ++_c) {
for (int _kh = 0; _kh < kernel_h; ++_kh) {
for (int _kw = 0; _kw < kernel_w; ++_kw) {
// c_col is the linearized access in the channels_col vector.
int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw;
// h_im and w_im are the actual height and width coordinates of
// the input tensor that we need to copy to the output.
int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h;
int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w;
// If the current data access is within the input tensor, copy the
// value
data_col[i_col * channels_col + c_col] =
(h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
? data_im[(_c * height + h_im) * width + w_im]
: static_cast<T>(in_zero_point);
}
}
}
}
}
}
}

void im2row_out(
__ET_UNUSED KernelRuntimeContext& ctx,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef dilation,
IntArrayRef padding,
IntArrayRef stride,
const Tensor& in_zero_point,
bool channel_last,
Tensor& out) {
// Compute the input tensor's dims
bool unit_height = input.dim() == 3;
const int32_t batch_size = input.size(0);
const int32_t in_c =
channel_last ? input.size(3 - unit_height) : input.size(1);
const int32_t in_h =
unit_height ? 1 : (channel_last ? input.size(1) : input.size(2));
const int32_t in_w =
channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height);

// Get the kernel parameters
int32_t kernel_h = kernel_size[0];
int32_t kernel_w = kernel_size[1];
int32_t dilation_h = dilation[0];
int32_t dilation_w = dilation[1];
int32_t pad_h = padding[0];
int32_t pad_w = padding[1];
int32_t stride_h = stride[0];
int32_t stride_w = stride[1];

// If we were to apply a convolution on the input tensor, compute the output
// height and width.
int32_t out_h =
(in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
int32_t out_w =
(in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;

ET_DCHECK_MSG(
(out_h * out_w) == out.size(1), "dimension mismatch for output");
ET_DCHECK_MSG(
(kernel_h * kernel_w * in_c) == out.size(2),
"dimension mismatch for output");

// Check if the input is per-tensor quantized or per-channel quantized. The
// zero point for each batch could differ for per-channel quantized input.
bool per_tensor_quantized = in_zero_point.numel() == 1;

#define typed_im2row(dtype, ctype) \
case ScalarType::dtype: { \
const ctype* __restrict__ in_data = input.const_data_ptr<ctype>(); \
ctype* __restrict__ out_data = out.mutable_data_ptr<ctype>(); \
const int32_t* __restrict__ zero_point = \
in_zero_point.const_data_ptr<int32_t>(); \
int32_t in_plane = in_c * in_h * in_w; \
int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \
for (size_t n = 0; n < batch_size; ++n) { \
im2row_<ctype>( \
&in_data[n * in_plane], \
per_tensor_quantized ? zero_point[0] : zero_point[n], \
in_c, \
in_h, \
in_w, \
out_h, \
out_w, \
kernel_h, \
kernel_w, \
pad_h, \
pad_w, \
stride_h, \
stride_w, \
dilation_h, \
dilation_w, \
&out_data[n * out_plane], \
channel_last); \
} \
break; \
}

ScalarType dtype = input.scalar_type();
switch (dtype) {
typed_im2row(Float, float);
typed_im2row(Byte, uint8_t);
typed_im2row(Char, int8_t);
default:
ET_DCHECK_MSG(
false,
"im2row not implemented for dtype %s",
torch::executor::toString(dtype));
}
#undef typed_im2row
}

} // namespace native
} // namespace reference
} // namespace impl
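
As a quick sanity check on the shape logic above, the following hypothetical standalone snippet (not part of the diff) applies the same out_h/out_w formula and prints the per-batch im2row output shape, [out_h * out_w, channels * kernel_h * kernel_w]. A convolution then reduces to a matmul against the weight flattened to that column size.

#include <cstdint>
#include <cstdio>

int main() {
  // Example parameters; values are illustrative only.
  int32_t in_c = 3, in_h = 8, in_w = 8;
  int32_t kernel_h = 3, kernel_w = 3;
  int32_t pad_h = 1, pad_w = 1;
  int32_t stride_h = 1, stride_w = 1;
  int32_t dilation_h = 1, dilation_w = 1;

  // Same output-size formula as im2row_out above.
  int32_t out_h =
      (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
  int32_t out_w =
      (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;

  // Per-batch im2row output: [out_h * out_w, in_c * kernel_h * kernel_w]
  // For these parameters this prints [64, 27].
  std::printf("im2row output per batch: [%d, %d]\n",
              out_h * out_w, in_c * kernel_h * kernel_w);
  return 0;
}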
57 changes: 57 additions & 0 deletions backends/cadence/reference/operators/operators.h
@@ -0,0 +1,57 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <optional>

namespace cadence {
namespace impl {
namespace cpu {
namespace native {
namespace {
using ::executorch::runtime::getLeadingDims;

#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char)

inline __attribute__((always_inline)) void linear_(
const ::executorch::aten::Tensor& input,
const ::executorch::aten::Tensor& weight,
const ::executorch::aten::optional<::executorch::aten::Tensor>& bias,
::executorch::aten::Tensor& output) {
const float* __restrict__ input_data = input.const_data_ptr<float>();
const float* __restrict__ weight_data = weight.const_data_ptr<float>();
const float* __restrict__ bias_data = bias.value().const_data_ptr<float>();
float* __restrict__ output_data = output.mutable_data_ptr<float>();

// input comes in shape [batch_size, in_dim]
// weight comes in shape [out_dim, in_dim]
// output comes in empty with shape [batch_size, out_dim]
// Perform matrix multiply (M x N) x (N x P) => M x P
int64_t M = weight.size(0); // = out_dim
int64_t N = weight.size(1); // = in_dim

// Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the
// product of the leading dimensions is d0 * d1 * ... * d_{N-2}
int64_t leading_dims = getLeadingDims(input, input.dim() - 1);

for (int i = 0; i < leading_dims; ++i) {
for (int j = 0; j < M; ++j) {
float sum = bias_data[j];
for (int k = 0; k < N; ++k) {
sum += input_data[i * N + k] * weight_data[j * N + k];
}
output_data[i * M + j] = sum;
}
}
}

} // namespace
} // namespace native
} // namespace cpu
} // namespace impl
} // namespace cadence
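
For reference, here is a hypothetical plain-array version of the same loop structure as linear_ above, assuming a 2-D [batch, in_dim] input so that leading_dims equals the batch size; the values are made up for illustration.

#include <cstdio>

int main() {
  const int leading_dims = 2, N = 3, M = 2;            // batch, in_dim, out_dim
  float input[leading_dims * N] = {1, 2, 3, 4, 5, 6};  // [2, 3]
  float weight[M * N] = {1, 0, 0, 0, 1, 0};            // [2, 3]
  float bias[M] = {0.5f, -0.5f};
  float output[leading_dims * M];

  // Same matmul-with-bias loop nest as linear_: (leading_dims x N) x (N x M).
  for (int i = 0; i < leading_dims; ++i) {
    for (int j = 0; j < M; ++j) {
      float sum = bias[j];
      for (int k = 0; k < N; ++k) {
        sum += input[i * N + k] * weight[j * N + k];
      }
      output[i * M + j] = sum;
    }
  }
  // Prints: 1.5 1.5 4.5 4.5
  std::printf("%.1f %.1f %.1f %.1f\n",
              output[0], output[1], output[2], output[3]);
  return 0;
}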
