Support FlashAttention Backend for Hybrid SSM Models #23299
Conversation
Signed-off-by: Chen Zhang <[email protected]>
Code Review
This pull request aims to add support for the FlashAttention backend in hybrid SSM models by transforming the KV cache layout. The core logic in `_update_hybrid_attention_mamba_layout` uses `as_strided_` to change the tensor layout from `(2, num_blocks, ...)` to `(num_blocks, 2, ...)`. However, the current implementation of this transformation appears to be incorrect and will likely lead to bugs. I've provided a critical review comment with a suggested fix to correctly perform the layout change.
kv_cache.as_strided_(size=kv_cache.shape,
                     stride=(hidden_size, 2 * hidden_size,
                             *kv_cache.stride()[2:]))
The use of `as_strided_` here to change the tensor layout from `(2, num_blocks, ...)` to `(num_blocks, 2, ...)` appears to be incorrect for two main reasons:

- Incorrect `size`: the `size` argument is set to `kv_cache.shape`, so the tensor's shape is not actually changed. Downstream consumers of this cache (like Mamba layers) will still see the original `(2, num_blocks, ...)` shape, which defeats the purpose of this function.
- Incorrect `stride`: the stride `(hidden_size, 2 * hidden_size, ...)` does not correspond to a simple transpose of a contiguous tensor. For a contiguous tensor of shape `(2, num_blocks, hidden_size)`, a transpose would result in strides of `(hidden_size, num_blocks * hidden_size)`. The current implementation will lead to incorrect memory access patterns unless `num_blocks` is 2, a case which is explicitly excluded by the assertion on line 3134.

A more correct and readable way to achieve an in-place transpose is to use `transpose_`:
kv_cache.transpose_(0, 1)
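For readers less familiar with in-place view ops, here is a minimal standalone sketch (toy sizes, plain PyTorch, not vLLM code) of what `transpose_` does: it swaps shape and stride metadata in place without moving any data.

```python
import torch

# Toy example: transpose_ rewrites only the tensor's metadata
# (shape and strides); the underlying storage is untouched.
t = torch.arange(2 * 3 * 4).view(2, 3, 4)
ptr_before = t.data_ptr()
print(t.shape, t.stride())  # torch.Size([2, 3, 4]) (12, 4, 1)

t.transpose_(0, 1)
print(t.shape, t.stride())  # torch.Size([3, 2, 4]) (4, 12, 1)
assert t.data_ptr() == ptr_before  # same storage, no copy was made
```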
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
"VLLM_ATTENTION_BACKEND=FLASHINFER") | ||
kv_cache = kv_caches[layer_name] | ||
if (isinstance(kv_cache_spec, AttentionSpec) | ||
and kv_cache.shape[0] == 2): |
This is great!

Suggestion: instead of inferring the attention backend from tensor shapes (`shape[0] == 2` plus an assert), could we maybe use the available metadata more directly? Two options:

1. Use `group.metadata_builder` (`Mamba1AttentionMetadataBuilder`, FlashInfer, etc.), or
2. Initialize `FullAttentionSpec` (and the rest) with `attn_module.attn_backend` or `attn_module.backend`, similar to how we did it for the mamba spec.

I think this approach would be more readable and explicit about what's happening. What do you think?
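To make the second option concrete, here is a hypothetical sketch; the class and field names below (`HypotheticalAttentionSpec`, `backend`) are illustrative stand-ins, not vLLM's actual `FullAttentionSpec` API.

```python
from dataclasses import dataclass


@dataclass
class HypotheticalAttentionSpec:
    """Illustrative stand-in for a KV cache spec that records its backend."""
    block_size: int
    num_kv_heads: int
    head_size: int
    backend: str  # e.g. "FLASH_ATTN_V1" or "FLASHINFER", copied from the attn module


def kv_cache_needs_layout_update(spec: HypotheticalAttentionSpec) -> bool:
    # The layout decision is driven by the recorded backend name
    # instead of inspecting kv_cache.shape[0] == 2 at runtime.
    return spec.backend == "FLASH_ATTN_V1"


# Example usage with made-up values:
spec = HypotheticalAttentionSpec(block_size=16, num_kv_heads=8,
                                 head_size=128, backend="FLASH_ATTN_V1")
assert kv_cache_needs_layout_update(spec)
```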
I don't think the builder can make the code clean, because it doesn't have an interface for whether the layout is (2, num_blocks) or (num_blocks, 2). And even if we decide it during KVCacheConfig creation, we still run into the same problem.
Can't we pass, for example, `attn_module.backend` to `FullAttentionSpec`? From what I saw it is set to Flashinfer/Flash_Attn_V1, etc. Isn't that enough to understand whether the shapes are in the order you are expecting? You would then be able to make this condition `kv_cache_spec.backend == FLASH_ATTN_V1`.
But what if a new attention backend is added? I want it to be more automated.
Got it, makes sense, let’s stick with the current approach then. Thanks
@heheda12345 can you resolve the conflicts, and then we can enable auto-merge?
Signed-off-by: Chen Zhang <[email protected]>
@tdoublep I resolved the merge conflict and enabled auto-merge
V1 test failure looks like it might be related to these changes?
Purpose
As mentioned in #20016, v1 hybrid SSM requires the attention KV cache layout to be (num_blocks, 2, hidden_size) so that blocks can be shared between attention layers and mamba layers. This PR supports the (2, num_blocks, hidden_size) layout by changing the tensor strides of the (2, num_blocks) dimensions to (hidden_size, 2 * hidden_size).
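As a minimal standalone sketch of the idea (plain PyTorch, made-up sizes, not the actual vLLM code path): the same buffer can be laid out block-major in memory for sharing with mamba, while being re-exposed to the attention backend as (2, num_blocks, hidden_size) purely through strides.

```python
import torch

# Hypothetical sizes, for illustration only.
num_blocks, hidden_size = 4, 8

# Physical buffer laid out block-major as (num_blocks, 2, hidden_size),
# so each block's K and V are adjacent in memory and the whole block
# can be shared with mamba layers.
buf = torch.arange(num_blocks * 2 * hidden_size, dtype=torch.float32)
physical = buf.view(num_blocks, 2, hidden_size)

# Re-expose the same memory as (2, num_blocks, hidden_size) by setting
# the strides of the first two dimensions to (hidden_size, 2 * hidden_size).
logical = buf.as_strided(
    size=(2, num_blocks, hidden_size),
    stride=(hidden_size, 2 * hidden_size, 1),
)

# Same storage, two views: logical[k, b] aliases physical[b, k].
assert logical.data_ptr() == physical.data_ptr()
assert torch.equal(logical[0, 2], physical[2, 0])
assert torch.equal(logical[1, 3], physical[3, 1])
```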
Test Plan
Run unit tests. Should pass
Test Result
Let's wait for CI.
(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
- Update supported_models.md and examples for a new model.