Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

graph: backend: dnnl: support permute for scale and zps inputs #2291

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

wzt1997
Copy link
Contributor

@wzt1997 wzt1997 commented Dec 19, 2024

Description

We recently supported the compressed SDPA patterns which incorporate scales and zero points as inputs. Also, we aim to support the problem with a Key shape of ( N, H, S, D ) with transpose_b=true for the QK MatMul. For such case with transposed MatMul, the graph library adds permute ops prior to the pattern input, including the scale and zp tensors. To be more specific, for the following woq MatMul with tranpose_b=true:
image

After the graph compilation, we will have the graph as follows. The permute ops will not touch the physical memories but only change the mds.
image

However, we faced a issue that DNNL backend always regards scale and zp tensors as abx format, leading to failing results, hence we need to set the tag explicitly as abx and execute extra reorders. After the change, the graph after compilation is like:
image

TODO

  • [WIP] Manually execute the reorder for sdpa primitive cases.

@wzt1997 wzt1997 added the component:graph-api Codeowner: @oneapi-src/onednn-graph label Dec 19, 2024
@wzt1997 wzt1997 self-assigned this Dec 19, 2024
@wzt1997 wzt1997 requested review from a team as code owners December 19, 2024 09:03
@github-actions github-actions bot added the component:tests Codeowner: @oneapi-src/onednn-arch label Dec 19, 2024
@wzt1997 wzt1997 force-pushed the zhitao/support-scale-permute branch from 953b5a6 to ce87611 Compare December 19, 2024 09:13
@wzt1997 wzt1997 force-pushed the zhitao/support-scale-permute branch 2 times, most recently from cb5d6c4 to 3a713a8 Compare December 20, 2024 01:33
@wzt1997
Copy link
Contributor Author

wzt1997 commented Dec 22, 2024

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

@wzt1997 wzt1997 force-pushed the zhitao/support-scale-permute branch from 3a713a8 to 1b80646 Compare December 23, 2024 07:13
@wzt1997
Copy link
Contributor Author

wzt1997 commented Dec 24, 2024

make test
enable benchdnn_nightly
disable benchdnn_all
enable benchdnn_graph

@wzt1997 wzt1997 force-pushed the zhitao/support-scale-permute branch from 1b80646 to b1eed4d Compare December 25, 2024 08:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants