Skip to content

Commit

Permalink
graph: backend: dnnl: fix shape check for per-channel dynamic quant
Browse files Browse the repository at this point in the history
  • Loading branch information
wzt1997 committed Dec 27, 2024
1 parent 4de2319 commit d29d8d3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
17 changes: 9 additions & 8 deletions src/graph/backend/dnnl/passes/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,7 @@ static status_t dynamic_dequant_handler(
rewriter.to_insert(mul_scales);

if (has_zps) {
value_ptr scales = in_vals[1], zps = in_vals[2];
const auto &scale_dims = ltw(scales->get_logical_tensor()).vdims();
const auto &zp_dims = ltw(zps->get_logical_tensor()).vdims();
for (size_t idx = 0; idx < scale_dims.size(); ++idx) {
VCHECK_INVALID_ARGUMENT((scale_dims[idx] == zp_dims[idx]),
"scale and zero point tensors should have the same shape");
}

value_ptr zps = in_vals[2];
const int64_t zps_data_type = zps->get_logical_tensor().data_type;
op_ptr sub_zps = std::make_shared<op_t>(op_kind::dnnl_sub_zps);
sub_zps->connect_input(1, zps);
Expand All @@ -632,6 +625,14 @@ static status_t dynamic_dequant_handler(
sub_zps->set_attr<std::string>(op_attr::qtype, qtype);
sub_zps->set_attr<int64_t>(op_attr::data_type, zps_data_type);
if (is_group_quantization) {
value_ptr scales = in_vals[1];
const auto &scale_dims = ltw(scales->get_logical_tensor()).vdims();
const auto &zp_dims = ltw(zps->get_logical_tensor()).vdims();
for (size_t idx = 0; idx < scale_dims.size(); ++idx) {
VCHECK_INVALID_ARGUMENT((scale_dims[idx] == zp_dims[idx]),
"scale and zero point tensors should have the same "
"shape");
}
const auto &group_shape = cur_op->get_attr<std::vector<int64_t>>(
op_attr::group_shape);
sub_zps->set_attr<std::vector<int64_t>>(
Expand Down
26 changes: 21 additions & 5 deletions src/graph/interface/op_def_constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,27 @@ bool check_dyn_quant_dequant_scales_zps(const op_t *n) {
// in case of not setting value for zps
if (sz_zps == DNNL_GRAPH_UNKNOWN_DIM) { return true; }

VCHECK_SHAPE_INFER((sz_scales == sz_zps),
"%s, scales and zps should keep same. given scale "
"size: %d, given zp size: %d.",
op_t::kind2str(n->get_kind()).c_str(),
static_cast<int>(sz_scales), static_cast<int>(sz_zps));
if (qtype == "per_group") {
const auto &ndims
= n->get_input_value(1)->get_logical_tensor().ndims;
const auto &scale_dims
= n->get_input_value(1)->get_logical_tensor().dims;
const auto &zp_dims
= n->get_input_value(2)->get_logical_tensor().dims;
VCHECK_SHAPE_INFER(
(std::equal(scale_dims, scale_dims + ndims, zp_dims)),
"%s, scales and zps should keep same for group quant",
op_t::kind2str(n->get_kind()).c_str());
}

if (qtype == "per_channel") {
VCHECK_SHAPE_INFER((sz_zps == 1 || sz_scales == sz_zps),
"%s, zps should be 1 or equals to scales size for "
"per_channel policy, given zps size: %d and scales size: "
"%d",
op_t::kind2str(n->get_kind()).c_str(),
static_cast<int>(sz_zps), static_cast<int>(sz_scales));
}

if (qtype == "per_tensor") {
VCHECK_SHAPE_INFER((sz_zps == 1),
Expand Down

0 comments on commit d29d8d3

Please sign in to comment.