graph: backend: dnnl: support permute for scale and zps inputs #2291

Open
wants to merge 4 commits into base: main
5 changes: 5 additions & 0 deletions src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -179,6 +179,11 @@ status_t sdp_primitive_config_t::initial_check(
}
}
if (op_kind != graph::op_kind::MatMul) continue;
// TODO(zhitao): execute the reorder for scale and zps manually if the
// transpose attribute is specified as true.
if (cur_op->has_attr(op_attr::transpose_b)
&& cur_op->get_attr<bool>(op_attr::transpose_b))
return status::unimplemented;
auto post_op = get_post_op(cur_op);
if (post_op && mm1_post_op_kind.count(post_op->get_kind())) {
mm1 = cur_op;
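Note: the TODO above refers to physically repacking the scale/zps data when transpose_b is set. As a rough illustration (not the backend's actual code path), a manual reorder with the plain oneDNN C++ API could look like the sketch below; the 32x4 shape and f32 data type are assumptions for the example.

```cpp
#include "dnnl.hpp"

// Sketch: repack a transposed scales tensor into a dense (abx) layout so a
// primitive that only honors plain layouts can consume it. Shapes and data
// type are illustrative, not taken from the PR.
void reorder_transposed_scales(dnnl::engine &eng, dnnl::stream &strm,
        void *src_data, void *dst_data) {
    // Logical 32x4 scales whose physical layout is transposed (ba).
    dnnl::memory::desc src_md({32, 4}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ba);
    // Desired dense row-major (ab) layout for the same logical shape.
    dnnl::memory::desc dst_md({32, 4}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ab);

    dnnl::memory src_mem(src_md, eng, src_data);
    dnnl::memory dst_mem(dst_md, eng, dst_data);

    // The reorder primitive performs the physical repacking.
    dnnl::reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
    strm.wait();
}
```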
3 changes: 0 additions & 3 deletions src/graph/backend/dnnl/passes/insert_ops.cpp
@@ -431,9 +431,6 @@ status_t insert_permute_for_dynamic_mul_scale_sub_zp(
std::swap(group_shape[ndims - 1], group_shape[ndims - 2]);
cur_op->set_attr<std::vector<int64_t>>(
op_attr::group_shape, group_shape);
} else { // per-channel quantization
const auto axis = cur_op->get_attr<int64_t>(op_attr::axis);
cur_op->set_attr<int64_t>(op_attr::axis, (2 * ndims - 3) - axis);
}
}

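For context on the removed branch: the (2 * ndims - 3) - axis remap simply swaps an axis between the last two positions. A quick standalone check of that arithmetic (illustration only; the remap itself is what the PR removes):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int64_t ndims = 4;
    // The removed per-channel remap: reflect an axis across the last two dims.
    auto remap = [&](int64_t axis) { return (2 * ndims - 3) - axis; };
    assert(remap(3) == 2); // last axis moves to second-to-last
    assert(remap(2) == 3); // and vice versa
    return 0;
}
```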
17 changes: 9 additions & 8 deletions src/graph/backend/dnnl/passes/lower.cpp
@@ -616,14 +616,7 @@ static status_t dynamic_dequant_handler(
rewriter.to_insert(mul_scales);

if (has_zps) {
value_ptr scales = in_vals[1], zps = in_vals[2];
const auto &scale_dims = ltw(scales->get_logical_tensor()).vdims();
const auto &zp_dims = ltw(zps->get_logical_tensor()).vdims();
for (size_t idx = 0; idx < scale_dims.size(); ++idx) {
VCHECK_INVALID_ARGUMENT((scale_dims[idx] == zp_dims[idx]),
"scale and zero point tensors should have the same shape");
}

value_ptr zps = in_vals[2];
const int64_t zps_data_type = zps->get_logical_tensor().data_type;
op_ptr sub_zps = std::make_shared<op_t>(op_kind::dnnl_sub_zps);
sub_zps->connect_input(1, zps);
@@ -632,6 +625,14 @@
sub_zps->set_attr<std::string>(op_attr::qtype, qtype);
sub_zps->set_attr<int64_t>(op_attr::data_type, zps_data_type);
if (is_group_quantization) {
value_ptr scales = in_vals[1];
const auto &scale_dims = ltw(scales->get_logical_tensor()).vdims();
const auto &zp_dims = ltw(zps->get_logical_tensor()).vdims();
for (size_t idx = 0; idx < scale_dims.size(); ++idx) {
VCHECK_INVALID_ARGUMENT((scale_dims[idx] == zp_dims[idx]),
"scale and zero point tensors should have the same "
"shape");
}
const auto &group_shape = cur_op->get_attr<std::vector<int64_t>>(
op_attr::group_shape);
sub_zps->set_attr<std::vector<int64_t>>(
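The shape check moved into the group-quantization branch enforces that scales and zero points have matching dims. For per-group quantization those dims are the input dims divided element-wise by group_shape; the sketch below spells that out using the shapes from the benchdnn case added in this PR (1x32x32x128 data with group_shape 1x1x1x32 and 1x32x32x4 scales/zps).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Under group quantization, each group of group_shape[i] elements along
// dimension i shares one scale/zero point, so scales and zps must both have
// dims[i] / group_shape[i] elements along that dimension.
std::vector<int64_t> expected_scale_dims(const std::vector<int64_t> &dims,
        const std::vector<int64_t> &group_shape) {
    std::vector<int64_t> out(dims.size());
    for (size_t i = 0; i < dims.size(); ++i)
        out[i] = dims[i] / group_shape[i];
    return out;
}

int main() {
    const std::vector<int64_t> dims {1, 32, 32, 128};
    const std::vector<int64_t> group_shape {1, 1, 1, 32};
    const std::vector<int64_t> zp_dims {1, 32, 32, 4};

    // Mirrors the VCHECK_INVALID_ARGUMENT loop in the hunk above.
    assert(expected_scale_dims(dims, group_shape) == zp_dims);
    return 0;
}
```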
20 changes: 20 additions & 0 deletions src/graph/backend/dnnl/subgraph.cpp
@@ -443,6 +443,26 @@ void subgraph_rewriter_t::insert_op_before(const op_ptr &inserted_op,
auto in_dtype = in_val->get_logical_tensor().data_type;
new_val->set_data_type(in_dtype);

if (inserted_op->get_kind() == op_kind::dnnl_permute
&& (base_op->get_kind() == op_kind::dnnl_mul_scales
|| base_op->get_kind() == op_kind::dnnl_sub_zps)) {
// Only the abx tag is respected for scale and zps inputs, so the strides
// should be set explicitly and a reorder executed.

dnnl::memory::desc in_md
= make_dnnl_memory_desc(in_val->get_logical_tensor());
const auto &perm = inserted_op->get_attr<std::vector<int64_t>>(
op_attr::permutation);
std::vector<int> int_perm(perm.size(), -1);
for (size_t i = 0; i < perm.size(); i++) {
int_perm[i] = static_cast<int>(perm[i]);
}
dnnl::memory::desc out_md = in_md.permute_axes(int_perm);
const auto &dims = out_md.get_dims();
// set the strides with abx tag.
new_val->set_strides(get_dense_strides(dims));
}

if (k == std::numeric_limits<size_t>::max()) {
k = inserted_op->num_outputs();
}
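For reference, a standalone sketch of what the new branch computes: permute the axes of the input memory descriptor, then derive dense (abx) strides for the permuted shape. `dense_strides` below is a stand-in for the backend's get_dense_strides helper, and the 1x32x32x4 shape is borrowed from the benchdnn case added in this PR.

```cpp
#include <cstdint>
#include <vector>
#include "dnnl.hpp"

// Stand-in for the backend's get_dense_strides(): row-major (abx) strides.
std::vector<int64_t> dense_strides(const dnnl::memory::dims &dims) {
    std::vector<int64_t> strides(dims.size(), 1);
    for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * dims[i + 1];
    return strides;
}

int main() {
    // A dense 1x32x32x4 scales tensor.
    dnnl::memory::desc in_md({1, 32, 32, 4}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::abcd);

    // Swap the last two axes, as the inserted dnnl_permute op does.
    const std::vector<int> perm {0, 1, 3, 2};
    dnnl::memory::desc out_md = in_md.permute_axes(perm);

    // permute_axes() keeps describing the same physical buffer, so its
    // strides are not dense for the new logical order. Setting dense strides
    // on the new value instead forces an actual reorder of the data.
    const auto strides = dense_strides(out_md.get_dims());
    // For dims {1, 32, 4, 32} this yields {4096, 128, 32, 1}.
    (void)strides;
    return 0;
}
```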
12 changes: 7 additions & 5 deletions src/graph/interface/op_def_constraint.cpp
@@ -328,11 +328,13 @@ bool check_dyn_quant_dequant_scales_zps(const op_t *n) {
// in case of not setting value for zps
if (sz_zps == DNNL_GRAPH_UNKNOWN_DIM) { return true; }

VCHECK_SHAPE_INFER((sz_scales == sz_zps),
"%s, scales and zps should keep same. given scale "
"size: %d, given zp size: %d.",
op_t::kind2str(n->get_kind()).c_str(),
static_cast<int>(sz_scales), static_cast<int>(sz_zps));
if (qtype == "per_group") {
VCHECK_SHAPE_INFER((sz_scales == sz_zps),
"%s, scales and zps should keep same. given scale "
"size: %d, given zp size: %d.",
op_t::kind2str(n->get_kind()).c_str(),
static_cast<int>(sz_scales), static_cast<int>(sz_zps));
}

if (qtype == "per_tensor") {
VCHECK_SHAPE_INFER((sz_zps == 1),
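A condensed restatement of the size rules visible in this hunk after the change (the per_channel branch is not shown in the diff, so it is left out here):

```cpp
#include <cstdint>
#include <string>

// Sketch of the zps-size constraints as they appear in this hunk after the
// change; only the branches visible in the diff are reproduced.
bool zps_size_ok(const std::string &qtype, int64_t sz_scales, int64_t sz_zps) {
    if (qtype == "per_group") return sz_scales == sz_zps; // must match scales
    if (qtype == "per_tensor") return sz_zps == 1;        // single zero point
    return true; // other qtypes: no constraint expressed in this hunk
}
```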
3 changes: 3 additions & 0 deletions tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all
@@ -44,6 +44,9 @@
--reset --dt=f32,bf16,f16 --in-shapes=0:32x16x128x64+1:32x16x128x64+5:32x16x128x128+8:32x16x128x64 --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=0:acbd+1:acbd+8:acbd --case=complex_fusion/mha/sdpa-plain-simplified-f16.json
--reset --dt=f32,bf16,f16 --in-shapes=3:384,3:384x384,3:1x16x384x384 --case=complex_fusion/mha/sdpa-plain-scale-by-mul-f16.json
--reset --op-attrs=34107656704:group_shape:1x1x1x32+34107654464:transpose_b:1 --in-shapes=0:1x32x32x128+1:1x32x32x4+2:1x32x32x4 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --op-attrs=34107656704:qtype:per_channel*axis:3 --in-shapes=1:32+2:1 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json
--reset --op-attrs=34107656704:qtype:per_channel*axis:2+34107654464:transpose_b:1 --in-shapes=0:1x32x32x128+1:32+2:1 --case=complex_fusion/mha/sdpa-compressed-k-int8-gs32.json

# Re-written int8 graphs
--reset --in-shapes=5:4x16x32x256+4:4x16x256x33+0:4x16x33x256+1:4x1x1x33+3:4x1x32x33 --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json