-
-
Notifications
You must be signed in to change notification settings - Fork 10.2k
allow calc_kv_scales #23906
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
allow calc_kv_scales #23906
Conversation
Summary: 1. The self.calc_kv_scales() should be invoked without checking `attn_metadata` 2. `attn_metadata` is avail only when full graph mode of cudagraph. If user did not use it, there is an error when checking `attn_metadata.enable_kv_scales_calculation` This diff should fix the above problem. But we can not use torch.compile when we set `calculate_kv_scales=True`, it will complain using .item() in `def calc_kv_scales()` Differential Revision: D81300417
This pull request was exported from Phabricator. Differential Revision: D81300417 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request fixes a crash that occurs when calculate_kv_scales
is enabled but the code is not running in cudagraph full graph mode. The fix removes a dependency on attn_metadata
, which can be None
in this scenario. While this correctly addresses the crash, it introduces a potential new issue: calc_kv_scales
is called unconditionally, but it will fail if key
or value
tensors are None
. The existence of checks for key is not None
and value is not None
later in the forward
method suggests this is a valid possibility. I've added a review comment to guard the call to calc_kv_scales
to prevent this potential crash.
attn_metadata = get_forward_context().attn_metadata | ||
if attn_metadata.enable_kv_scales_calculation: | ||
self.calc_kv_scales(query, key, value) | ||
self.calc_kv_scales(query, key, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The call to self.calc_kv_scales
here could lead to a TypeError
if key
or value
is None
, as torch.abs(None)
would be executed. Later in this method (lines 260-263), there are checks for key is not None
and value is not None
, which implies they can indeed be None
. To prevent a potential crash, it's crucial to ensure key
and value
are not None
before calling calc_kv_scales
.
self.calc_kv_scales(query, key, value) | |
if key is not None and value is not None: | |
self.calc_kv_scales(query, key, value) |
@heheda12345 Do you mind reviewing this as I saw you touched it in #12536 Thanks! |
@mgoin Can you help to take a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want to always enable scale calculation if not necessary.
Summary:
When running the gpt-oss, I found that there is a bug when enabling
calculate_kv_scales
:attn_metadata
attn_metadata
is avail only when full graph mode of cudagraph. If a user does not use it, there is an error(NoneType) when checkingattn_metadata.enable_kv_scales_calculation
This PR should fix the above problem.
But we can not use torch.compile when we set
calculate_kv_scales=True
, it will complain using .item() indef calc_kv_scales()
Differential Revision: D81300417