Skip to content

Commit

Permalink
Make filling output type generic
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska committed Nov 6, 2024
1 parent 3b9a845 commit 864bc49
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/plugins/intel_cpu/src/nodes/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2271,17 +2271,16 @@ void Reduce::execute(dnnl::stream strm) {
auto dstMemPtr = getDstMemoryAtPort(0);
auto srcMemPtr = getSrcMemoryAtPort(REDUCE_DATA);

const auto src_shape = getSrcMemoryAtPort(REDUCE_DATA)->getStaticDims();
const uint8_t *src_data = srcMemPtr->getDataAs<const uint8_t>();
uint8_t *dst_data = dstMemPtr->getDataAs<uint8_t>();

const auto& src_shape = getSrcMemoryAtPort(REDUCE_DATA)->getStaticDims();
if ((shape_size(src_shape) == 0 || srcMemPtr->getSize() == 0) && dstMemPtr->getSize() > 0) {
// If input is empty fill ouptut with zero
auto dst_shape = getDstMemoryAtPort(0)->getStaticDims();
std::fill_n(dstMemPtr->getDataAs<float>(), shape_size(dst_shape), 0.f);
std::fill_n(dst_data, dstMemPtr->getSize(), uint8_t{0});
return;
}

const uint8_t *src_data = srcMemPtr->getDataAs<const uint8_t>();
uint8_t *dst_data = dstMemPtr->getDataAs<uint8_t>();

if (jit_mode) {
if (is_hybrid_layout) {
dst_data = reinterpret_cast<uint8_t *>(prc_mem.get_data_handle());
Expand Down

0 comments on commit 864bc49

Please sign in to comment.