
Conversation

heheda12345
Collaborator

@heheda12345 commented Aug 21, 2025

Purpose

As mentioned in #20016, v1 hybrid SSM support requires the attention KV cache layout to be (num_blocks, 2, hidden_size) so that blocks can be shared between attention layers and mamba layers. This PR adds support for attention backends that use the (2, num_blocks, hidden_size) layout by changing the strides of the (2, num_blocks) dimensions to (hidden_size, 2 * hidden_size): the logical shape the backend sees stays the same, while the underlying memory keeps the (num_blocks, 2, hidden_size) layout.
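
For intuition, here is a minimal, standalone PyTorch sketch of the stride trick (illustrative sizes and variable names only, not the exact code or shapes used in the PR): a buffer laid out physically as (num_blocks, 2, hidden_size) is exposed through shape (2, num_blocks, hidden_size) with strides (hidden_size, 2 * hidden_size, 1), which is exactly the transpose of that physical layout.

    import torch

    num_blocks, hidden_size = 4, 8

    # Physical layout needed for sharing blocks with mamba: each block stores
    # its K page and V page back to back, i.e. (num_blocks, 2, hidden_size).
    buf = torch.arange(num_blocks * 2 * hidden_size, dtype=torch.float32)
    physical = buf.view(num_blocks, 2, hidden_size)

    # Logical view expected by the attention backend: (2, num_blocks, hidden_size),
    # with strides chosen so that element [i, j, :] lands on the i-th half of
    # block j's contiguous 2*hidden_size chunk.
    logical = buf.as_strided(
        size=(2, num_blocks, hidden_size),
        stride=(hidden_size, 2 * hidden_size, 1),
    )

    # The strided view is exactly the transpose of the physical layout.
    assert torch.equal(logical, physical.transpose(0, 1))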

Test Plan

Run the unit tests; they should pass.

Test Result

Let's wait for CI.

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: Chen Zhang <[email protected]>
@heheda12345 requested a review from tdoublep August 21, 2025 01:07
Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request aims to add support for the FlashAttention backend in hybrid SSM models by transforming the KV cache layout. The core logic in _update_hybrid_attention_mamba_layout uses as_strided_ to change the tensor layout from (2, num_blocks, ...) to (num_blocks, 2, ...). However, the current implementation of this transformation appears to be incorrect and will likely lead to bugs. I've provided a critical review comment with a suggested fix to correctly perform the layout change.

Comment on lines +3139 to +3141

    kv_cache.as_strided_(size=kv_cache.shape,
                         stride=(hidden_size, 2 * hidden_size,
                                 *kv_cache.stride()[2:]))
Contributor

critical

The use of as_strided_ here to change the tensor layout from (2, num_blocks, ...) to (num_blocks, 2, ...) appears to be incorrect for two main reasons:

  1. Incorrect size: The size argument is set to kv_cache.shape, which means the tensor's shape is not actually changed. Downstream consumers of this cache (like Mamba layers) will still see the original (2, num_blocks, ...) shape, which defeats the purpose of this function.
  2. Incorrect stride: The stride (hidden_size, 2 * hidden_size, ...) does not correspond to a simple transpose of a contiguous tensor. For a contiguous tensor of shape (2, num_blocks, hidden_size), a transpose would result in strides of (hidden_size, num_blocks * hidden_size). The current implementation will lead to incorrect memory access patterns unless num_blocks is 2, a case which is explicitly excluded by the assertion on line 3134.

A more correct and readable way to achieve an in-place transpose is to use transpose_.

                    kv_cache.transpose_(0, 1)
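
A quick numeric check of the stride arithmetic above, with small illustrative sizes (it only compares stride values; whether the as_strided_ call is correct depends on how the underlying buffer is actually laid out, which is what the rest of the discussion turns on):

    import torch

    num_blocks, hidden_size = 4, 8

    contiguous = torch.zeros(2, num_blocks, hidden_size)
    print(contiguous.stride())                  # (32, 8, 1) == (num_blocks*hidden_size, hidden_size, 1)
    print(contiguous.transpose(0, 1).stride())  # (8, 32, 1) == (hidden_size, num_blocks*hidden_size, 1)

    # Strides set by the PR's as_strided_ call (logical shape kept as (2, num_blocks, ...)):
    print((hidden_size, 2 * hidden_size, 1))    # (8, 16, 1)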

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot commented Aug 22, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Aug 22, 2025
"VLLM_ATTENTION_BACKEND=FLASHINFER")
kv_cache = kv_caches[layer_name]
if (isinstance(kv_cache_spec, AttentionSpec)
and kv_cache.shape[0] == 2):
Contributor

This is great!
Suggestion: instead of inferring the attention backend from tensor shapes (shape[0] == 2 plus an assert), could we maybe use the available metadata more directly?
Two options:

  • Use group.metadata_builder (Mamba1AttentionMetadataBuilder, FlashInfer, etc.), or
  • Initialize FullAttentionSpec (and the rest) with attn_module.attn_backend or attn_module.backend, similar to how we did it for the mamba spec.

I think this approach would be more readable and explicit about what's happening. What do you think?

Collaborator Author

I don't think the builder can make the code clean, because it doesn't have an interface that tells whether the layout is (2, num_blocks) or (num_blocks, 2). And even if we decide it during KVCacheConfig creation, we still face the same problem.

Contributor

@Josephasafg Aug 24, 2025

Can't we pass, for example, attn_module.backend to FullAttentionSpec? From what I saw, it is set to Flashinfer, Flash_Attn_V1, etc. Isn't that enough to tell whether the shapes are in the order you expect? You would then be able to write the condition as kv_cache_spec.backend == FLASH_ATTN_V1.
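
For illustration only, a rough sketch of the alternative proposed here; the spec class, its backend field, and the helper below are hypothetical and not part of the PR:

    from dataclasses import dataclass

    @dataclass
    class HypotheticalAttentionSpec:
        block_size: int
        backend: str  # e.g. "FLASH_ATTN_V1" or "FLASHINFER"; hypothetical field

    def needs_layout_update(spec: HypotheticalAttentionSpec) -> bool:
        # Branch on the recorded backend name instead of inferring the layout
        # from kv_cache.shape[0] == 2.
        return spec.backend == "FLASH_ATTN_V1"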

Collaborator Author

But what if a new attention backend is added? I want it to be more automated.

Contributor

Got it, makes sense, let’s stick with the current approach then. Thanks

Member

@tdoublep left a comment

This is a really elegant solution.

Had to spend a bit of time thinking this through carefully, and made a figure to convince myself I properly understood it. I've attached that below in case it's helpful for anyone else.

[attached figure]

@tdoublep
Member

@heheda12345 can you resolve conflicts and then we enable auto-merge?

@tdoublep added the ready (ONLY add when PR is ready to merge/full CI is needed) label Aug 25, 2025
@heheda12345 enabled auto-merge (squash) August 25, 2025 21:31
@mergify bot removed the needs-rebase label Aug 25, 2025
@heheda12345
Collaborator Author

@tdoublep I resolved the merge conflict and enabled auto-merge

@tdoublep
Member

V1 test failure looks like it might be related to these changes?

@heheda12345 merged commit 2b4fc9b into vllm-project:main Aug 26, 2025
35 of 36 checks passed
@heheda12345 deleted the hybrid_fa3 branch August 26, 2025 12:41
tc-mb pushed a commit to tc-mb/vllm that referenced this pull request Aug 27, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Labels: ready, v1