Support enable_gqa and only support 4D Q, K, and V #2558
Conversation
matcher=lambda sample: len(sample.input.shape) != 4
    or len(sample.args[0].shape) != 4
    or len(sample.args[1].shape) != 4,
reason="torch sdpa is expected to pass in 4d q, k, and v.",
@justinchuby @xadupre Let me know what you think about whether we should support only 4D Q, K, and V, or fully support whatever torch sdpa supports. Right now it seems QKV can be 3D or 4D in torch sdpa, and even 3D Q with 4D K and V.
Depending on the ATen op? Does the nn function do preprocessing on the inputs before sending them to the kernel? We just need to support whatever the kernel supports.
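For reference, a minimal sketch of the shape question being discussed: at the Python level, torch.nn.functional.scaled_dot_product_attention accepts both 3D and 4D Q/K/V (leading batch dims broadcast), so the decision here is whether the exporter normalizes to 4D before reaching the op. The shapes below are illustrative, not taken from the test suite.

    import torch
    import torch.nn.functional as F

    # 4D layout: (batch, heads, seq, head_dim)
    q4 = torch.randn(2, 8, 16, 64)
    k4 = torch.randn(2, 8, 16, 64)
    v4 = torch.randn(2, 8, 16, 64)
    out_4d = F.scaled_dot_product_attention(q4, k4, v4)

    # 3D layout: no explicit batch dimension
    q3 = torch.randn(8, 16, 64)
    k3 = torch.randn(8, 16, 64)
    v3 = torch.randn(8, 16, 64)
    out_3d = F.scaled_dot_product_attention(q3, k3, v3)

    print(out_4d.shape, out_3d.shape)
    # torch.Size([2, 8, 16, 64]) torch.Size([8, 16, 64])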
Codecov Report
❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2558      +/-   ##
==========================================
- Coverage   70.34%   70.32%   -0.03%
==========================================
  Files         218      222       +4
  Lines       26430    26645     +215
  Branches     2647     2663      +16
==========================================
+ Hits        18593    18738     +145
- Misses       6934     6991      +57
- Partials      903      916      +13

☔ View full report in Codecov by Sentry.
Thanks!
Could you also add these few lines as a micro optimization? Or we can do that separately.
It's already there:
Fixes #162258 Related to microsoft/onnxscript#2558 Pull Request resolved: #162771 Approved by: https://github.com/justinchuby
enable_gqa
NOTE: torch.nn.functional.scaled_dot_product_attention actually supports 3D inputs, and even 3D Q with 4D K and V, in the op tests.
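A small hedged example of the enable_gqa path named in the PR title (the enable_gqa flag is available in PyTorch 2.5 and later): with enable_gqa=True, K and V may carry fewer heads than Q, provided Q's head count is an integer multiple of the K/V head count. The shapes below are illustrative.

    import torch
    import torch.nn.functional as F

    # Grouped-query attention: 32 query heads share 8 key/value heads.
    q = torch.randn(1, 32, 128, 64)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)

    out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    print(out.shape)  # torch.Size([1, 32, 128, 64])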