Conversation

titaiwangms (Contributor)

  1. Support enable_gqa (see the sketch after the note below).
  2. Align with the PyTorch exporter, which rejects Q, K, and V when they are not 4D: https://github.com/pytorch/pytorch/blob/62843c14bbf694f5722fd6e1075da4792507fe42/torch/onnx/_internal/exporter/_torchlib/ops/nn.py#L131-L133

NOTE: torch.nn.functional.scaled_dot_product_attention itself supports 3D inputs, and the op tests even exercise a 3D Q with a 4D K and V.
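For item 1, a minimal eager-mode sketch of what enable_gqa means (assuming PyTorch >= 2.5, where the flag exists; shapes are illustrative): the query may carry more heads than the key/value, and each group of query heads shares one KV head.

import torch
import torch.nn.functional as F

# Grouped-query attention: 8 query heads share 2 key/value heads (8 % 2 == 0).
q = torch.randn(2, 8, 16, 64)  # (batch, num_q_heads, seq_len, head_dim)
k = torch.randn(2, 2, 16, 64)  # (batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(2, 2, 16, 64)
out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])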

Review thread on the exporter test matcher:

matcher=lambda sample: len(sample.input.shape) != 4
or len(sample.args[0].shape) != 4
or len(sample.args[1].shape) != 4,
reason="torch sdpa is expected to pass in 4d q, k, and v.",
titaiwangms (Contributor, Author):
@justinchuby @xadupre Let me know what you think about whether we should support only 4D QKV or fully support whatever torch sdpa supports. Right now, torch sdpa seems to accept Q, K, and V as 3D or 4D, and even a 3D Q with a 4D K and V.
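For reference, a minimal plain-PyTorch sketch (illustrative shapes) of the input ranks under discussion:

import torch
import torch.nn.functional as F

# 4D layout the exporter currently requires: (batch, num_heads, seq_len, head_dim).
q4 = torch.randn(2, 4, 8, 16)
k4 = torch.randn(2, 4, 8, 16)
v4 = torch.randn(2, 4, 8, 16)
print(F.scaled_dot_product_attention(q4, k4, v4).shape)  # torch.Size([2, 4, 8, 16])

# 3D inputs also run in eager mode: (batch, seq_len, head_dim).
q3 = torch.randn(2, 8, 16)
k3 = torch.randn(2, 8, 16)
v3 = torch.randn(2, 8, 16)
print(F.scaled_dot_product_attention(q3, k3, v3).shape)  # torch.Size([2, 8, 16])

# The op tests referenced above also mix a 3D Q with a 4D K/V via batch broadcasting.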

Collaborator:
Depending on the ATen op? Does the nn function preprocess the inputs before sending them to the kernel? We just need to support whatever the kernel supports.


codecov bot commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 14.28571% with 24 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.32%. Comparing base (647b22a) to head (e661531).
⚠️ Report is 6 commits behind head on main.

Files with missing lines | Patch % | Lines
onnxscript/function_libs/torch_lib/ops/nn.py | 14.28% | 22 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2558      +/-   ##
==========================================
- Coverage   70.34%   70.32%   -0.03%     
==========================================
  Files         218      222       +4     
  Lines       26430    26645     +215     
  Branches     2647     2663      +16     
==========================================
+ Hits        18593    18738     +145     
- Misses       6934     6991      +57     
- Partials      903      916      +13     


justinchuby (Collaborator) left a comment:

Thanks!

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Sep 11, 2025
@justinchuby (Collaborator):

Could you also add these few lines

if dropout_p > 0.0:
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)

as a micro-optimization? Or we can do that separately.

@titaiwangms (Contributor, Author):

> Could you also add these few lines
>
>     if dropout_p > 0.0:
>         attn_weight, _ = op.Dropout(attn_weight, dropout_p)
>
> as a micro-optimization? Or we can do that separately.

It's already there:

attn_weight, _ = op.Dropout(attn_weight, dropout_p)
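For context, the point of the guard is to skip emitting a Dropout node entirely when dropout_p == 0.0, since a zero-ratio Dropout is an identity that only bloats inference graphs. A minimal sketch of the pattern (hypothetical standalone helper; op here is onnxscript's opset18, standing in for the opset handle torchlib binds):

from onnxscript import opset18 as op  # stand-in for the opset handle torchlib uses

def _maybe_dropout(attn_weight, dropout_p: float):
    # Emit Dropout only when it has an effect; with dropout_p == 0.0 the node
    # would be an identity and just add noise to the exported graph.
    if dropout_p > 0.0:
        attn_weight, _ = op.Dropout(attn_weight, dropout_p)  # second output is the mask
    return attn_weight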

@titaiwangms titaiwangms enabled auto-merge (squash) September 12, 2025 00:37
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Sep 12, 2025
@titaiwangms titaiwangms merged commit 8ed3521 into microsoft:main Sep 12, 2025
32 checks passed
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025