Contributor

@frank-wei frank-wei commented Aug 29, 2025

Summary:
When running gpt-oss, I found a bug that appears when `calculate_kv_scales` is enabled:

  1. `self.calc_kv_scales()` should be invoked without checking `attn_metadata`.
  2. `attn_metadata` is available only in cudagraph full graph mode. If a user does not use that mode, checking `attn_metadata.enable_kv_scales_calculation` fails with a NoneType error.

This PR should fix the above problem.

However, torch.compile cannot be used when `calculate_kv_scales=True`: it complains about the use of `.item()` in `calc_kv_scales()`.
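
The failure mode described in item 2 can be reproduced without vLLM. The sketch below is a minimal stand-in, not the project's code: `AttnMetadata`, `ForwardContext`, `old_check`, and `describe` are hypothetical names, and the only behavior modeled is that `attn_metadata` may be `None`, as the description states.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttnMetadata:
    # Hypothetical stand-in; only the flag named in the PR is modeled.
    enable_kv_scales_calculation: bool = True


@dataclass
class ForwardContext:
    # Assumption taken from the PR text: None outside full graph mode.
    attn_metadata: Optional[AttnMetadata] = None


def old_check(ctx: ForwardContext) -> bool:
    """Mirror the pre-PR check: read the flag without testing for None."""
    return ctx.attn_metadata.enable_kv_scales_calculation


def describe(ctx: ForwardContext) -> str:
    """Run the check and report either the flag value or the crash."""
    try:
        return f"flag={old_check(ctx)}"
    except AttributeError as exc:
        return f"crash: {exc}"


print(describe(ForwardContext(AttnMetadata())))  # metadata present
print(describe(ForwardContext(None)))            # metadata missing
```

With metadata present the flag is read normally; with `None` the attribute access raises `AttributeError`, which is the NoneType error reported above.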

Differential Revision: D81300417

@facebook-github-bot
This pull request was exported from Phabricator. Differential Revision: D81300417

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request fixes a crash that occurs when calculate_kv_scales is enabled but the code is not running in cudagraph full graph mode. The fix removes a dependency on attn_metadata, which can be None in this scenario. While this correctly addresses the crash, it introduces a potential new issue: calc_kv_scales is called unconditionally, but it will fail if key or value tensors are None. The existence of checks for key is not None and value is not None later in the forward method suggests this is a valid possibility. I've added a review comment to guard the call to calc_kv_scales to prevent this potential crash.

- attn_metadata = get_forward_context().attn_metadata
- if attn_metadata.enable_kv_scales_calculation:
-     self.calc_kv_scales(query, key, value)
+ self.calc_kv_scales(query, key, value)
critical

The call to self.calc_kv_scales here could lead to a TypeError if key or value is None, as torch.abs(None) would be executed. Later in this method (lines 260-263), there are checks for key is not None and value is not None, which implies they can indeed be None. To prevent a potential crash, it's crucial to ensure key and value are not None before calling calc_kv_scales.

Suggested change
- self.calc_kv_scales(query, key, value)
+ if key is not None and value is not None:
+     self.calc_kv_scales(query, key, value)
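
The effect of the suggested guard can be sketched with plain Python lists in place of tensors. Everything here is an assumption made for illustration: `AttentionSketch` is a hypothetical toy class, and its `calc_kv_scales` uses a simple abs-max reduction as a stand-in for the real method, which is not shown in this thread.

```python
from typing import Optional, Sequence


class AttentionSketch:
    """Toy layer: scales are floats, 'tensors' are sequences of floats."""

    def __init__(self) -> None:
        self.k_scale = 1.0
        self.v_scale = 1.0

    def calc_kv_scales(self, key: Sequence[float], value: Sequence[float]) -> None:
        # Stand-in abs-max reduction; raises TypeError if an input is None.
        self.k_scale = max(abs(x) for x in key)
        self.v_scale = max(abs(x) for x in value)

    def forward_unguarded(
        self, key: Optional[Sequence[float]], value: Optional[Sequence[float]]
    ) -> None:
        # Unconditional call, as in the PR before the suggestion.
        self.calc_kv_scales(key, value)

    def forward_guarded(
        self, key: Optional[Sequence[float]], value: Optional[Sequence[float]]
    ) -> None:
        # Suggested change: skip the calculation when either input is None.
        if key is not None and value is not None:
            self.calc_kv_scales(key, value)


layer = AttentionSketch()
layer.forward_guarded(None, None)            # no crash, scales unchanged
layer.forward_guarded([0.5, -2.0], [3.0])    # scales updated from inputs
print(layer.k_scale, layer.v_scale)
```

In this toy version the unguarded path raises `TypeError` on `None` inputs, while the guarded path leaves the previous scales in place.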

@22quinn 22quinn requested a review from heheda12345 September 5, 2025 06:51
Collaborator

22quinn commented Sep 5, 2025

@heheda12345 Do you mind reviewing this? I saw you touched it in #12536. Thanks!

@heheda12345
Collaborator

@mgoin Can you help to take a look?

Collaborator

@ProExpertProg ProExpertProg left a comment

I don't think we want to always enable scale calculation if not necessary.
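
One way to picture the trade-off raised here is a check that keeps the per-step flag but tolerates missing metadata. This is only a sketch under stated assumptions, not a proposed patch: `should_calc_kv_scales` and its `default_when_missing` parameter are hypothetical, and what to do when `attn_metadata` is `None` is exactly the question this thread leaves open.

```python
from types import SimpleNamespace
from typing import Any, Optional


def should_calc_kv_scales(
    layer_flag: bool,
    attn_metadata: Optional[Any],
    default_when_missing: bool = False,
) -> bool:
    """Decide whether to run the scale calculation for this forward pass."""
    if not layer_flag:
        # calculate_kv_scales is off for the layer: never compute.
        return False
    if attn_metadata is None:
        # Metadata unavailable: fall back to a caller-chosen policy
        # (hypothetical knob, not part of the PR).
        return default_when_missing
    # Metadata available: honor its per-step flag.
    return bool(attn_metadata.enable_kv_scales_calculation)


meta_off = SimpleNamespace(enable_kv_scales_calculation=False)
print(should_calc_kv_scales(True, meta_off))                          # flag honored
print(should_calc_kv_scales(True, None, default_when_missing=True))   # policy applied
```

Setting `default_when_missing=True` matches the PR's always-calculate behavior when metadata is absent, while `False` matches the reviewer's preference to avoid unnecessary calculation.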
