From 4e44e91a16915f504c701a938bd905271c87e1fc Mon Sep 17 00:00:00 2001 From: David Lin Date: Thu, 19 Dec 2024 07:25:35 -0800 Subject: [PATCH] Add mean.dtype_out op for Ads model Summary: title Differential Revision: D67453766 --- kernels/aten/functions.yaml | 2 ++ kernels/portable/cpu/op_mean.cpp | 8 ++++++++ kernels/portable/functions.yaml | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index ebcd86d851..7c54e05331 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -257,6 +257,8 @@ - op: mean.out +- op: mean.dtype_out + - op: min.dim_min - op: min.unary_out diff --git a/kernels/portable/cpu/op_mean.cpp b/kernels/portable/cpu/op_mean.cpp index aeb0d7f8ca..6730404dde 100644 --- a/kernels/portable/cpu/op_mean.cpp +++ b/kernels/portable/cpu/op_mean.cpp @@ -66,6 +66,14 @@ Tensor& mean_dim_out( return out; } +Tensor& mean_dtype_out( + KernelRuntimeContext& ctx, + const Tensor& in, + optional dtype, + Tensor& out) { + return mean_dim_out(ctx, in, ArrayRef(), false, dtype, out); +} + } // namespace native } // namespace executor } // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index 0da9917214..7119b7f66e 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -577,6 +577,11 @@ - arg_meta: null kernel_name: torch::executor::mean_dim_out +- op: mean.dtype_out + kernels: + - arg_meta: null + kernel_name: torch::executor::mean_dtype_out + - op: min.dim_min kernels: - arg_meta: null