Add portable upsample_bilinear2d kernel (#6923)

Summary: Add an upsample_bilinear2d kernel to the portable kernel library. This implementation reuses some of the inner logic from the ATen implementation (see Upsample.h and UpsampleKernel.cpp), but I have not ported the outer kernel structure, as it relies on TensorIterator and runtime allocation. It may be worth revisiting this in the future, either by pulling in more of the ATen implementation or by adding an optimized variant.

Test Plan: Added comprehensive operator-level test coverage for upsample_bilinear2d.

```
buck test //executorch/kernels/test:portable_op_upsample_bilinear2d_test
buck test //executorch/kernels/test:aten_op_upsample_bilinear2d_test
```

Reviewed By: manuelcandales

Differential Revision: D65756150

Pulled By: GregoryComer
1 parent 82763a9 · commit 1e4e960 · Showing 12 changed files with 1,146 additions and 0 deletions.
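For reference, the per-pixel computation in the kernel below is standard bilinear interpolation. Writing the kernel's `weight_h` as λ_h and `inv_weight_h` as 1 − λ_h (and likewise for the width axis), each output element is:

```latex
\mathrm{out}(n,c,h,w)
  = \lambda_h \left( \lambda_w\, x_{h_1,w_1} + (1-\lambda_w)\, x_{h_1,w_2} \right)
  + (1-\lambda_h) \left( \lambda_w\, x_{h_2,w_1} + (1-\lambda_w)\, x_{h_2,w_2} \right)
```

where (h1, h2) and (w1, w2) are the two nearest source rows and columns selected by `compute_source_index_and_lambda`.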
@@ -0,0 +1,139 @@

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::SizesType;

namespace {
template <typename CTYPE>
void upsample_bilinear2d_kernel_impl(
    const Tensor& in,
    bool align_corners,
    const float scale_h,
    const float scale_w,
    Tensor& out) {
  const auto in_data = in.const_data_ptr<CTYPE>();
  auto out_data = out.mutable_data_ptr<CTYPE>();

  auto in_plane = in_data;
  for (auto n = 0; n < out.size(0); n++) {
    for (auto c = 0; c < out.size(1); c++) {
      for (auto h = 0; h < out.size(2); h++) {
        // Compute source index and weights.
        int64_t in_h1, in_h2;
        float weight_h, inv_weight_h;

        compute_source_index_and_lambda(
            in_h1,
            in_h2,
            weight_h,
            inv_weight_h,
            scale_h,
            h,
            in.sizes()[2],
            out.sizes()[2],
            align_corners);

        for (auto w = 0; w < out.size(3); w++) {
          int64_t in_w1, in_w2;
          float weight_w, inv_weight_w;

          compute_source_index_and_lambda(
              in_w1,
              in_w2,
              weight_w,
              inv_weight_w,
              scale_w,
              w,
              in.sizes()[3],
              out.sizes()[3],
              align_corners);

          const auto top_left =
              in_plane[in_h1 * in.strides()[2] + in_w1 * in.strides()[3]];
          const auto top_right =
              in_plane[in_h1 * in.strides()[2] + in_w2 * in.strides()[3]];
          const auto bottom_left =
              in_plane[in_h2 * in.strides()[2] + in_w1 * in.strides()[3]];
          const auto bottom_right =
              in_plane[in_h2 * in.strides()[2] + in_w2 * in.strides()[3]];

          const auto top = top_left * weight_w + top_right * inv_weight_w;
          const auto bottom =
              bottom_left * weight_w + bottom_right * inv_weight_w;
          const auto val = top * weight_h + bottom * inv_weight_h;

          *out_data = val;
          out_data++;
        }
      }

      in_plane += in.strides()[1];
    }
  }
}
} // namespace

// Signatures are auto-generated, so disable pass-by-value lint.
// NOLINTBEGIN(facebook-hte-ConstantArgumentPassByValue,
// facebook-hte-ParameterMightThrowOnCopy)
Tensor& upsample_bilinear2d_vec_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t> output_size,
    bool align_corners,
    const exec_aten::OptionalArrayRef<double> scale_factors,
    Tensor& out) {
  // Preconditions (checked in check_..._args):
  //  In and out tensors have same dtype.
  //  In and out tensors are rank 4 and have same dim[0] and dim[1].
  //  In and out tensors are default dim order (NCHW).
  ET_KERNEL_CHECK(
      ctx,
      check_upsample_bilinear2d_args(
          in, output_size, align_corners, scale_factors, out),
      InvalidArgument,
      out);

  double scale_h, scale_w;

  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_upsample_2d(
          in, output_size, scale_factors, scale_h, scale_w, out) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor");

  const auto kernel_scale_h = area_pixel_compute_scale<double>(
      in.sizes()[2], out.sizes()[2], align_corners, scale_h);
  const auto kernel_scale_w = area_pixel_compute_scale<double>(
      in.sizes()[3], out.sizes()[3], align_corners, scale_w);

  ET_SWITCH_REAL_TYPES(
      in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() {
        upsample_bilinear2d_kernel_impl<CTYPE>(
            in, align_corners, kernel_scale_h, kernel_scale_w, out);
      });

  return out;
}
// NOLINTEND(facebook-hte-ConstantArgumentPassByValue,
// facebook-hte-ParameterMightThrowOnCopy)

} // namespace native
} // namespace executor
} // namespace torch
```
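The helper `compute_source_index_and_lambda` comes from ATen's Upsample.h and is not part of this diff. The sketch below is a minimal standalone illustration of the index/weight convention that helper follows (assuming ATen's half-pixel mapping for align_corners=false); the function name and exact parameter names here are illustrative, not the actual ATen signature:

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative sketch (not the ATen source) of the index/weight convention
// used by compute_source_index_and_lambda in the kernel above.
void source_index_and_lambda_sketch(
    int64_t& in_idx0,
    int64_t& in_idx1,
    float& lambda0, // weight applied to in_idx0
    float& lambda1, // weight applied to in_idx1 (equals 1 - lambda0)
    float scale, // ratio from area_pixel_compute_scale, roughly in/out
    int64_t out_idx,
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  if (output_size == input_size) {
    // Identity mapping: copy straight through.
    in_idx0 = out_idx;
    in_idx1 = out_idx;
    lambda0 = 1.0f;
    lambda1 = 0.0f;
    return;
  }
  // align_corners=true maps corner pixels exactly; otherwise use the
  // half-pixel convention, clamped at the lower edge.
  const float real_idx = align_corners
      ? scale * out_idx
      : std::max(scale * (out_idx + 0.5f) - 0.5f, 0.0f);
  in_idx0 = static_cast<int64_t>(real_idx);
  // The second index stays in bounds at the right/bottom edge, so the
  // kernel never reads past the last source row or column.
  in_idx1 = in_idx0 + (in_idx0 < input_size - 1 ? 1 : 0);
  lambda1 = real_idx - static_cast<float>(in_idx0);
  lambda0 = 1.0f - lambda1;
}
```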
@@ -0,0 +1,96 @@

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/upsample_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

bool check_upsample_2d_common_args(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(in.dim() == 4);
  ET_LOG_AND_RETURN_IF_FALSE(out.dim() == 4);
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_dim_order(out));
  ET_LOG_AND_RETURN_IF_FALSE(
      output_size.has_value() ^ scale_factors.has_value());
  if (scale_factors.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value().size() == 2);
    ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[0] > 0);
    ET_LOG_AND_RETURN_IF_FALSE(scale_factors.value()[1] > 0);
  } else if (output_size.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(output_size.value().size() == 2);
    ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[0] > 0);
    ET_LOG_AND_RETURN_IF_FALSE(output_size.value()[1] > 0);
  }

  return true;
}

bool check_upsample_bilinear2d_args(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    ET_UNUSED const bool align_corners,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    Tensor& out) {
  return check_upsample_2d_common_args(in, output_size, scale_factors, out);
}

Error resize_upsample_2d(
    const Tensor& in,
    const exec_aten::OptionalArrayRef<int64_t>& output_size,
    const exec_aten::OptionalArrayRef<double>& scale_factors,
    double& scale_h_out,
    double& scale_w_out,
    Tensor& out) {
  // Either output_size or scale_factors are provided, not both. This
  // is checked in check_..._args.
  // Scales are transformed according to align_corners.
  std::array<Tensor::SizesType, kTensorDimensionLimit> target_size;

  const auto dim = in.dim();
  std::copy(in.sizes().cbegin(), in.sizes().cend(), target_size.begin());

  if (scale_factors.has_value()) {
    scale_h_out = scale_factors.value()[0];
    scale_w_out = scale_factors.value()[1];

    target_size[dim - 2] =
        static_cast<Tensor::SizesType>(in.sizes()[dim - 2] * scale_h_out);
    target_size[dim - 1] =
        static_cast<Tensor::SizesType>(in.sizes()[dim - 1] * scale_w_out);
  } else if (output_size.has_value()) {
    scale_h_out =
        static_cast<double>(output_size.value()[0]) / in.sizes()[dim - 2];
    scale_w_out =
        static_cast<double>(output_size.value()[1]) / in.sizes()[dim - 1];

    target_size[dim - 2] = output_size.value()[0];
    target_size[dim - 1] = output_size.value()[1];
  } else {
    ET_LOG(Error, "Invalid output_size or scale_factors");
    return Error::InvalidArgument;
  }

  ET_CHECK_OR_RETURN_ERROR(
      target_size[dim - 2] > 0 && target_size[dim - 1] > 0,
      InvalidArgument,
      "Upsampled output size must be non-empty, but was %ld x %ld.",
      static_cast<long>(target_size[dim - 2]),
      static_cast<long>(target_size[dim - 1]));

  return resize_tensor(out, {target_size.data(), static_cast<size_t>(dim)});
}

} // namespace executor
} // namespace torch
```
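A quick worked example of the size/scale resolution in `resize_upsample_2d` above: when `scale_factors` is given, the output extent is the truncated product of the input size and the scale; when `output_size` is given, the scale is derived as output/input. The standalone sketch below just reproduces that arithmetic outside the ExecuTorch types, with sizes chosen for illustration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Case 1: scale_factors = {2.0, 1.5} on a 4x6 (H x W) input.
  const int64_t in_h = 4, in_w = 6;
  const double scale_h = 2.0, scale_w = 1.5;
  // Mirrors target_size[dim - 2] = in.sizes()[dim - 2] * scale_h_out,
  // truncated by the cast.
  const auto out_h = static_cast<int64_t>(in_h * scale_h); // 8
  const auto out_w = static_cast<int64_t>(in_w * scale_w); // 9
  std::printf(
      "scale_factors path: %lld x %lld\n",
      static_cast<long long>(out_h),
      static_cast<long long>(out_w));

  // Case 2: output_size = {8, 9} on the same input derives the scales.
  const double derived_scale_h = 8.0 / in_h; // 2.0
  const double derived_scale_w = 9.0 / in_w; // 1.5
  std::printf(
      "output_size path: scales %.2f, %.2f\n",
      derived_scale_h,
      derived_scale_w);
  return 0;
}
```

Note that the two paths round-trip here: the derived scales in case 2 match the scale factors in case 1, which is why the kernel can always hand a single pair of scales to area_pixel_compute_scale.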