From d29d8d3dff56cdfb022a6eca749ebe3cdcd12598 Mon Sep 17 00:00:00 2001 From: "Wang, Zhitao" Date: Thu, 26 Dec 2024 11:20:09 +0000 Subject: [PATCH] graph: backend: dnnl: fix shape check for per-channel dynamic quant --- src/graph/backend/dnnl/passes/lower.cpp | 17 ++++++++------- src/graph/interface/op_def_constraint.cpp | 26 ++++++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/graph/backend/dnnl/passes/lower.cpp b/src/graph/backend/dnnl/passes/lower.cpp index dca5d32909e..e4491dd46ff 100644 --- a/src/graph/backend/dnnl/passes/lower.cpp +++ b/src/graph/backend/dnnl/passes/lower.cpp @@ -616,14 +616,7 @@ static status_t dynamic_dequant_handler( rewriter.to_insert(mul_scales); if (has_zps) { - value_ptr scales = in_vals[1], zps = in_vals[2]; - const auto &scale_dims = ltw(scales->get_logical_tensor()).vdims(); - const auto &zp_dims = ltw(zps->get_logical_tensor()).vdims(); - for (size_t idx = 0; idx < scale_dims.size(); ++idx) { - VCHECK_INVALID_ARGUMENT((scale_dims[idx] == zp_dims[idx]), - "scale and zero point tensors should have the same shape"); - } - + value_ptr zps = in_vals[2]; const int64_t zps_data_type = zps->get_logical_tensor().data_type; op_ptr sub_zps = std::make_shared(op_kind::dnnl_sub_zps); sub_zps->connect_input(1, zps); @@ -632,6 +625,14 @@ static status_t dynamic_dequant_handler( sub_zps->set_attr(op_attr::qtype, qtype); sub_zps->set_attr(op_attr::data_type, zps_data_type); if (is_group_quantization) { + value_ptr scales = in_vals[1]; + const auto &scale_dims = ltw(scales->get_logical_tensor()).vdims(); + const auto &zp_dims = ltw(zps->get_logical_tensor()).vdims(); + for (size_t idx = 0; idx < scale_dims.size(); ++idx) { + VCHECK_INVALID_ARGUMENT((scale_dims[idx] == zp_dims[idx]), + "scale and zero point tensors should have the same " + "shape"); + } const auto &group_shape = cur_op->get_attr>( op_attr::group_shape); sub_zps->set_attr>( diff --git a/src/graph/interface/op_def_constraint.cpp b/src/graph/interface/op_def_constraint.cpp index 7efd6213561..cd5c6db1555 100644 --- a/src/graph/interface/op_def_constraint.cpp +++ b/src/graph/interface/op_def_constraint.cpp @@ -328,11 +328,27 @@ bool check_dyn_quant_dequant_scales_zps(const op_t *n) { // in case of not setting value for zps if (sz_zps == DNNL_GRAPH_UNKNOWN_DIM) { return true; } - VCHECK_SHAPE_INFER((sz_scales == sz_zps), - "%s, scales and zps should keep same. given scale " - "size: %d, given zp size: %d.", - op_t::kind2str(n->get_kind()).c_str(), - static_cast(sz_scales), static_cast(sz_zps)); + if (qtype == "per_group") { + const auto &ndims + = n->get_input_value(1)->get_logical_tensor().ndims; + const auto &scale_dims + = n->get_input_value(1)->get_logical_tensor().dims; + const auto &zp_dims + = n->get_input_value(2)->get_logical_tensor().dims; + VCHECK_SHAPE_INFER( + (std::equal(scale_dims, scale_dims + ndims, zp_dims)), + "%s, scales and zps should keep same for group quant", + op_t::kind2str(n->get_kind()).c_str()); + } + + if (qtype == "per_channel") { + VCHECK_SHAPE_INFER((sz_zps == 1 || sz_scales == sz_zps), + "%s, zps should be 1 or equals to scales size for " + "per_channel policy, given zps size: %d and scales size: " + "%d", + op_t::kind2str(n->get_kind()).c_str(), + static_cast(sz_zps), static_cast(sz_scales)); + } if (qtype == "per_tensor") { VCHECK_SHAPE_INFER((sz_zps == 1),