
Conversation

@wwl2755 (Contributor) commented Sep 5, 2025

Purpose

  1. Enable FA3 in ViT. FA3 was not supported originally because head_dim has to be a multiple of 32 (in vLLM's build). However, upstream FA can theoretically support any head_dim that is a multiple of 8. (https://github.com/Dao-AILab/flash-attention/blob/v2.8.3/csrc/flash_attn/flash_api.cpp#L605)

vLLM's native FA build does not support this because it sets FLASHATTENTION_DISABLE_UNEVEN_K, which removes the padding path for head dims that are not a multiple of 32. (http://github.com/vllm-project/flash-attention/blob/v2.6.2/setup.py#L223)

So, we use vLLM's native flash attention where applicable and upstream FA as a fallback.
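For illustration, a minimal sketch of that selection (simplified; the actual change lives in the ViT/MHA attention code shown in the review comments below):

# Sketch only: pick the varlen FA function based on head_dim support.
# vLLM's vllm-flash-attn build needs head_dim % 32 == 0; upstream
# flash-attn only needs head_dim % 8 == 0.
def get_flash_attn_varlen_func(head_dim: int):
    if head_dim % 32 == 0:
        from vllm.vllm_flash_attn import flash_attn_varlen_func
    else:
        from flash_attn import flash_attn_varlen_func  # upstream fallback
    return flash_attn_varlen_func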

Previous attempt and revert: #12435, #12355, #12445

  2. Refactor get_vit_attn_backend() so that it can decide the attention backend type internally.

Related to #23880, #23884

cc: @ywang96 @DarkLight1337 @Isotr0py

Test

vllm serve llava-hf/llava-onevision-qwen2-7b-ov-hf --trust-remote-code

vllm bench serve  \
	--backend openai-chat   \
	--model llava-hf/llava-onevision-qwen2-7b-ov-hf   \
	--endpoint /v1/chat/completions   \
	--dataset-name random-mm   \
	--num-prompts 100   \
	--random-prefix-len 0   \
	--random-input-len 0  \
	--random-output-len 1   \
	--random-range-ratio 0   \
	--random-mm-base-items-per-request 1   \
	--random-mm-limit-mm-per-prompt '{"image": 1, "video": 0}'   \
	--random-mm-bucket-config '{(224, 224, 1): 1.0}'   \
	--request-rate 10   \
	--ignore-eos   \
	--seed 42 \
	--endpoint-type openai-chat

This model's ViT head_dim is 72, which is not a multiple of 32.

vllm serve OpenGVLab/InternVL3_5-4B --trust-remote-code

vllm bench serve  \
	--backend openai-chat   \
	--model llava-hf/llava-onevision-qwen2-7b-ov-hf   \
        ...

This model's ViT head_dim is 64, which is a multiple of 32.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request enables FlashAttention 3 in ViT, which is a great enhancement. The implementation correctly falls back to the upstream FlashAttention library when the head dimension is not a multiple of 32, which is necessary for certain models. My review includes a suggestion to optimize the import of the FlashAttention function to avoid performance overhead in the forward pass.

Comment on lines 402 to 437
if self.head_size % 32 != 0:
    # import from upstream flash_attn
    from flash_attn import flash_attn_varlen_func
else:
    from vllm.vllm_flash_attn import flash_attn_varlen_func

cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                            step=q_len,
                            dtype=torch.int32,
                            device=query.device)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                            step=kv_len,
                            dtype=torch.int32,
                            device=key.device)

out = flash_attn_varlen_func(
    query.flatten(0, 1),
    key.flatten(0, 1),
    value.flatten(0, 1),
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=q_len,
    max_seqlen_k=kv_len,
    softmax_scale=self.scale,
)
gemini-code-assist bot (Contributor), severity: high

Placing imports inside a method on a hot path, like forward, can introduce significant performance overhead due to repeated module lookups. It's better to perform the import only once.

A good approach here would be to lazily initialize and cache the flash_attn_varlen_func on the first call. This avoids the import overhead on subsequent calls. While the ideal place for this logic is in the __init__ method, lazy initialization within forward is a great improvement that keeps the changes within this method.

            if not hasattr(self, "_flash_attn_varlen_func"):
                if self.head_size % 32 != 0:
                    # import from upstream flash_attn
                    from flash_attn import flash_attn_varlen_func
                else:
                    from vllm.vllm_flash_attn import flash_attn_varlen_func
                self._flash_attn_varlen_func = flash_attn_varlen_func

            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                        step=q_len,
                                        dtype=torch.int32,
                                        device=query.device)
            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                                        step=kv_len,
                                        dtype=torch.int32,
                                        device=key.device)

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )

Comment on lines 402 to 404
if self.head_size % 32 != 0:
    # import from upstream flash_attn
    from flash_attn import flash_attn_varlen_func
@Isotr0py (Member) commented Sep 6, 2025

Hmmm, I think get_attn_backend won't and shouldn't return the FA backend if head_size is incompatible. Can you check whether this codepath is actually active?

@wwl2755 (Contributor Author) commented Sep 6, 2025

You're right. This codepath doesn't work in this commit. Thanks for pointing it out!

In my local commit, I tested by hard-coding the supported_head_sizes() in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flash_attn.py#L50 and it works fine.

In this case, I think we may have different options to handle the cases where head_size % 32 != 0:

  1. Still use xformers. -> We just delete this part and leave it for a future refactor, or until vLLM's native fork changes its settings.
  2. Change supported_head_sizes() to multiples of 8. -> This change may affect some other places; not sure if it is good.
    2.1 Refactor MHA to use get_vit_attn_backend. -> ViT attention has fewer choices, so it should be easier to handle than option 2.

WDYT? And for head_size % 32 == 0, I think it should be fine as long as we are happy with the precision.

@ywang96 (Member) commented Sep 7, 2025

Wouldn't this be an issue since upstream flash_attn is not a dependency of vLLM?

If so I think it's probably a better idea to check this at model init time within get_vit_attn_backend. I'm thinking:

  1. If a head size is supported by vLLM FA, use vLLM FA
  2. If a head size is not supported by vLLM FA but upstream is detected, use upstream FA.
  3. If a head size is not supported by vLLM FA but upstream FA is not installed, we simply fall back to xformers, but print something like "Using xformers for ViT attention backend. To use flash attention for ViT please install flash_attn"

and get_vit_attn_backend should be responsible for the decision making here and return the true attention backend to use. We can simply return another boolean from get_vit_attn_backend to differentiate upstream FA and vLLM FA if that's the concern.
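A minimal sketch of that decision flow (simplified; the actual change returns _Backend enum members, and per later review the % 32 check gets replaced by an is_attn_backend_supported check):

def pick_vit_attn_backend(head_size: int) -> tuple[str, bool]:
    # Sketch only: returns (backend_name, use_upstream_fa).
    from transformers.utils import is_flash_attn_2_available
    if head_size % 32 == 0:          # head sizes vLLM's FA build handles
        return "FLASH_ATTN", False   # vllm-flash-attn fork
    if is_flash_attn_2_available():  # upstream flash_attn is installed
        return "FLASH_ATTN", True
    print("Using xformers for ViT attention backend. "
          "To use flash attention for ViT please install flash_attn")
    return "XFORMERS", False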

@wwl2755 (Contributor Author) replied

get_vit_attn_backend should be responsible for doing the decision making here

Totally agree with the fallback algorithm. It's pretty much what I planned as option 2.1. The current concern is that MHA attention is still using get_attn_backend. I'm not sure whether there are any other scenarios/corner cases (besides ViT) that would trigger this code path. If there aren't any, I can refactor get_vit_attn_backend and use it in MHA instead.

@wwl2755 (Contributor Author) commented Sep 7, 2025

After scanning the files that reference MHA, they all seem to be ViT related. So I think it should be fine to switch to get_vit_attn_backend there.

I will update the PR soon.

@DarkLight1337 DarkLight1337 added the multi-modality Related to multi-modality (#4194) label Sep 6, 2025
@wwl2755 (Contributor Author) commented Sep 6, 2025

Here are detailed lm_eval results for OpenGVLab/InternVL3_5-4B (head_dim=64) and llava-hf/llava-onevision-qwen2-7b-ov-hf (head_dim=72). cc: @ywang96

OpenGVLab/InternVL3_5-4B

(head_dim is 64)

lm_eval --model vllm-vlm \
  --model_args pretrained=OpenGVLab/InternVL3_5-4B \
  --tasks mmmu_val \
  --num_fewshot 5 \
  --limit 250 \
  --trust_remote_code \
  --batch_size 'auto'  \
  --apply_chat_template

_Backend.FLASH_ATTN_VLLM_V1

Groups                           Version  Filter  n-shot  Metric  Value     Stderr
mmmu_val                               0  none            acc     0.5078  ± 0.0161
- Art and Design                       0  none            acc     0.6333  ± 0.0419
- Business                             0  none            acc     0.4733  ± 0.0405
- Health and Medicine                  0  none            acc     0.5467  ± 0.0412
- Humanities and Social Science        0  none            acc     0.7000  ± 0.0415
- Science                              0  none            acc     0.4067  ± 0.0404
- Tech and Engineering                 0  none            acc     0.3952  ± 0.0333

_Backend.XFORMERS

Groups                           Version  Filter  n-shot  Metric  Value     Stderr
mmmu_val                               0  none            acc     0.5033  ± 0.0161
- Art and Design                       0  none            acc     0.6500  ± 0.0423
- Business                             0  none            acc     0.4533  ± 0.0406
- Health and Medicine                  0  none            acc     0.5467  ± 0.0412
- Humanities and Social Science        0  none            acc     0.6917  ± 0.0418
- Science                              0  none            acc     0.4133  ± 0.0406
- Tech and Engineering                 0  none            acc     0.3810  ± 0.0330

llava-hf/llava-onevision-qwen2-7b-ov-hf

(head_dim is 72)

lm_eval --model vllm-vlm \
  --model_args pretrained=llava-hf/llava-onevision-qwen2-7b-ov-hf \
  --tasks mmmu_val \
  --num_fewshot 5 \
  --limit 10 \
  --trust_remote_code \
  --batch_size 'auto'  \
  --apply_chat_template > vllm.log 2>&1

_Backend.FLASH_ATTN_VLLM_V1 (from upstream flash_attn)

Groups                           Version  Filter  n-shot  Metric  Value     Stderr
mmmu_val                               0  none            acc     0.4567  ± 0.0281
- Art and Design                       0  none            acc     0.6500  ± 0.0717
- Business                             0  none            acc     0.4600  ± 0.0733
- Health and Medicine                  0  none            acc     0.3800  ± 0.0715
- Humanities and Social Science        0  none            acc     0.5750  ± 0.0692
- Science                              0  none            acc     0.3600  ± 0.0673
- Tech and Engineering                 0  none            acc     0.4000  ± 0.0599

_Backend.XFORMERS

Groups                           Version  Filter  n-shot  Metric  Value     Stderr
mmmu_val                               0  none            acc     0.4567  ± 0.0281
- Art and Design                       0  none            acc     0.6500  ± 0.0717
- Business                             0  none            acc     0.4600  ± 0.0733
- Health and Medicine                  0  none            acc     0.3800  ± 0.0715
- Humanities and Social Science        0  none            acc     0.5750  ± 0.0692
- Science                              0  none            acc     0.3600  ± 0.0673
- Tech and Engineering                 0  none            acc     0.4000  ± 0.0599

@wwl2755 wwl2755 requested a review from sighingnow as a code owner September 7, 2025 22:09
@mergify mergify bot added qwen Related to Qwen models rocm Related to AMD ROCm labels Sep 7, 2025
@wwl2755 (Contributor Author) commented Sep 7, 2025

@ywang96 @Isotr0py @DarkLight1337 Gentle ping for a quick review, if you have some time, to see whether this refactor makes sense to you. I will work on updating the function interfaces in the related tests in the meantime.

@david6666666 (Contributor) commented

Can you add a benchmark comparison?

@wwl2755 (Contributor Author) commented Sep 8, 2025

@david6666666 Here is the benchmark.

vllm serve Qwen/Qwen2-VL-7B-Instruct --trust-remote-code

vllm bench serve  \
	--backend openai-chat   \
	--model Qwen/Qwen2-VL-7B-Instruct   \
	--endpoint /v1/chat/completions   \
	--dataset-name random-mm   \
	--num-prompts 100   \
	--max-concurrency 10 \
	--random-prefix-len 0   \
	--random-input-len 1  \
	--random-output-len 1   \
	--random-range-ratio 0   \
	--random-mm-base-items-per-request 2 \
	--random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
	--random-mm-bucket-config '{(224, 224, 1): 1.0}' \
	--request-rate inf \
	--ignore-eos   \
	--seed 42 \
	--endpoint-type openai-chat
	
# _Backend.XFORMERS
============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  2.42      
Total input tokens:                      100       
Total generated tokens:                  100       
Request throughput (req/s):              41.39     
Output token throughput (tok/s):         41.39     
Total Token throughput (tok/s):          82.77     
---------------Time to First Token----------------
Mean TTFT (ms):                          231.65    
Median TTFT (ms):                        211.86    
P99 TTFT (ms):                           428.26    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           0.00      
---------------Inter-token Latency----------------
Mean ITL (ms):                           0.02      
Median ITL (ms):                         0.02      
P99 ITL (ms):                            0.04      
==================================================

# _Backend.FLASH_ATTN, use_upstream_fa: True
============ Serving Benchmark Result ============
Successful requests:                     100       
Maximum request concurrency:             10        
Benchmark duration (s):                  2.40      
Total input tokens:                      100       
Total generated tokens:                  100       
Request throughput (req/s):              41.63     
Output token throughput (tok/s):         41.63     
Total Token throughput (tok/s):          83.26     
---------------Time to First Token----------------
Mean TTFT (ms):                          231.86    
Median TTFT (ms):                        210.86    
P99 TTFT (ms):                           425.17    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           0.00      
---------------Inter-token Latency----------------
Mean ITL (ms):                           0.02      
Median ITL (ms):                         0.02      
P99 ITL (ms):                            0.04      
==================================================

Comment on lines 211 to 226
if cls.has_device_capability(80):
    if head_size % 32 == 0:
        # Use vllm-flash-attn
        return _Backend.FLASH_ATTN, False
    if head_size % 32 != 0:
        from transformers.utils import is_flash_attn_2_available
        if is_flash_attn_2_available():
            # Use upstream FA
            return _Backend.FLASH_ATTN, True
        else:
            # Fallback to XFORMERS
            logger.warning_once(
                "Using xformers for ViT attention backend. "
                "To use flash attention for ViT"
                "please install flash_attn")
            return _Backend.XFORMERS, False
Member

Suggested change
# Before:
if cls.has_device_capability(80):
    if head_size % 32 == 0:
        # Use vllm-flash-attn
        return _Backend.FLASH_ATTN, False
    if head_size % 32 != 0:
        from transformers.utils import is_flash_attn_2_available
        if is_flash_attn_2_available():
            # Use upstream FA
            return _Backend.FLASH_ATTN, True
        else:
            # Fallback to XFORMERS
            logger.warning_once(
                "Using xformers for ViT attention backend. "
                "To use flash attention for ViT"
                "please install flash_attn")
            return _Backend.XFORMERS, False

# After:
if cls.has_device_capability(80):
    is_default_fa_supported = is_attn_backend_supported(
        FLASH_ATTN_V1, head_size, dtype,
        allow_import_error=False)
    is_upstream_fa_supported = is_flash_attn_2_available()
    if is_default_fa_supported:
        return _Backend.FLASH_ATTN, False
    elif is_upstream_fa_supported:
        return _Backend.FLASH_ATTN, True
    return _Backend.XFORMERS, False

head_size % 32 == 0 may not be a robust solution, because the fork FA's supported head sizes might change in the future. I think we can use is_attn_backend_supported here.

@wwl2755 (Contributor Author) replied

Updated.

@Isotr0py (Member) commented Sep 8, 2025

_Backend.XFORMERS
Output token throughput (tok/s): 41.39
Total Token throughput (tok/s): 82.77
---------------Time to First Token----------------
Mean TTFT (ms): 231.65
Median TTFT (ms): 211.86
P99 TTFT (ms): 428.26

_Backend.FLASH_ATTN, use_upstream_fa: True
Output token throughput (tok/s): 41.63
Total Token throughput (tok/s): 83.26
---------------Time to First Token----------------
Mean TTFT (ms): 231.86
Median TTFT (ms): 210.86
P99 TTFT (ms): 425.17

Hmmm, it seems the benefit from FA is still quite minor for ViT (this matches my benchmark conclusion last time). I'm hesitant about whether it's worthwhile to introduce the extra maintenance burden of upstream FA... 🤔

@wwl2755 (Contributor Author) commented Sep 8, 2025

Hmmm, seems the benefit from FA is still quite minor for ViT (this is same as my benchmark conclusion last time), I'm hesitant if it's worthwile to introduce extra maintenance from upstream FA... 🤔

Yes, the improvement seems subtle. I'm testing with some larger images to see if that makes a difference.

However, I think we should still refactor this because MHA was previously using the get_attention_backend() codepath. After refactoring get_vit_attn_backend, it should be easier to consolidate those individual models like qwen2-vl.

introduce extra maintenance from upstream FA

If upstream FA breaks anything, I think we could easily disable it (by forcing the second return value of get_vit_attn_backend to False). And upstream FA would not take effect unless the user manually installs it.

@Isotr0py (Member) commented Sep 8, 2025

However, I think we still should refactor this because previous mha is using the codepath of get_attention_backend(). After refactoring the get_vit_attn_backend, it should be easier to consolidate those individual models like qwen2-vl.

Hmmm, I think the exact blocker for Qwen2-VL-style ViT consolidation is how to unify attention_mask preparation with optimization. Whether to use get_attention_backend() or get_vit_attn_backend won't be a big issue IMO.

@wwl2755 (Contributor Author) commented Sep 8, 2025

Here are more test results. I'm using flash_attn==2.7.4.post1 (released in January 2025) as the upstream FA because of my older Ubuntu version.

What we compared:

batch_size=1, 2 
seq_len=1024, 2048, 4096  
num_heads=16 
head_size=64, 80 
num_kv_heads=16
backend=FLASH_ATTN, FLASH_ATTN (upstream), TORCH_SDPA, XFORMERS 
================================================================================
BENCHMARK RESULTS
================================================================================

Successful runs: 42/48

Performance by Backend:
--------------------------------------------------
FLASH_ATTN     :     0.56ms avg,    6677821 tokens/sec avg
FLASH_ATTN (upstream):     0.65ms avg,    5713239 tokens/sec avg
TORCH_SDPA     :     0.62ms avg,    6953648 tokens/sec avg
XFORMERS       :     0.78ms avg,    4442269 tokens/sec avg

Detailed Results:
========================================================================================================================
Configuration: batch_size=1, seq_len=1024, num_heads=16, head_size=64, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.14ms ± 0.00ms |  7236238 tok/s
  FLASH_ATTN                  |   0.20ms ± 0.00ms |  5018823 tok/s
  FLASH_ATTN      (upstream)  |   0.25ms ± 0.01ms |  4128957 tok/s
  XFORMERS                    |   0.40ms ± 0.02ms |  2592195 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=1, seq_len=1024, num_heads=16, head_size=80, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.18ms ± 0.00ms |  5576228 tok/s
  FLASH_ATTN      (upstream)  |   0.27ms ± 0.00ms |  3762513 tok/s
  XFORMERS                    |   0.41ms ± 0.02ms |  2483506 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=1, seq_len=2048, num_heads=16, head_size=64, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.27ms ± 0.00ms |  7582075 tok/s
  FLASH_ATTN                  |   0.34ms ± 0.00ms |  6101422 tok/s
  FLASH_ATTN      (upstream)  |   0.36ms ± 0.00ms |  5729846 tok/s
  XFORMERS                    |   0.49ms ± 0.00ms |  4143831 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=1, seq_len=2048, num_heads=16, head_size=80, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.39ms ± 0.00ms |  5256108 tok/s
  FLASH_ATTN      (upstream)  |   0.44ms ± 0.00ms |  4660630 tok/s
  XFORMERS                    |   0.57ms ± 0.01ms |  3581487 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=1, seq_len=4096, num_heads=16, head_size=64, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.70ms ± 0.00ms |  5888902 tok/s
  FLASH_ATTN      (upstream)  |   0.74ms ± 0.01ms |  5511516 tok/s
  FLASH_ATTN                  |   0.76ms ± 0.01ms |  5358508 tok/s
  XFORMERS                    |   0.87ms ± 0.00ms |  4725724 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=1, seq_len=4096, num_heads=16, head_size=80, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  FLASH_ATTN      (upstream)  |   1.00ms ± 0.00ms |  4088187 tok/s
  TORCH_SDPA                  |   1.06ms ± 0.00ms |  3857065 tok/s
  XFORMERS                    |   1.12ms ± 0.01ms |  3647293 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=2, seq_len=1024, num_heads=16, head_size=64, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.18ms ± 0.00ms | 11490735 tok/s
  FLASH_ATTN                  |   0.24ms ± 0.00ms |  8500572 tok/s
  FLASH_ATTN      (upstream)  |   0.27ms ± 0.00ms |  7520414 tok/s
  XFORMERS                    |   0.41ms ± 0.00ms |  4979442 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=2, seq_len=1024, num_heads=16, head_size=80, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.24ms ± 0.00ms |  8566497 tok/s
  FLASH_ATTN      (upstream)  |   0.31ms ± 0.00ms |  6556944 tok/s
  XFORMERS                    |   0.45ms ± 0.01ms |  4575754 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=2, seq_len=2048, num_heads=16, head_size=64, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.39ms ± 0.00ms | 10412348 tok/s
  FLASH_ATTN                  |   0.46ms ± 0.00ms |  8970121 tok/s
  FLASH_ATTN      (upstream)  |   0.47ms ± 0.00ms |  8735247 tok/s
  XFORMERS                    |   0.61ms ± 0.01ms |  6716568 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=2, seq_len=2048, num_heads=16, head_size=80, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  TORCH_SDPA                  |   0.58ms ± 0.00ms |  7061736 tok/s
  FLASH_ATTN      (upstream)  |   0.60ms ± 0.00ms |  6837219 tok/s
  XFORMERS                    |   0.73ms ± 0.01ms |  5628773 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=2, seq_len=4096, num_heads=16, head_size=64, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  FLASH_ATTN      (upstream)  |   1.28ms ± 0.00ms |  6395585 tok/s
  TORCH_SDPA                  |   1.29ms ± 0.00ms |  6345274 tok/s
  FLASH_ATTN                  |   1.34ms ± 0.00ms |  6117482 tok/s
  XFORMERS                    |   1.40ms ± 0.00ms |  5864070 tok/s

------------------------------------------------------------------------------------------------------------------------
Configuration: batch_size=2, seq_len=4096, num_heads=16, head_size=80, num_kv_heads=16
------------------------------------------------------------------------------------------------------------------------
  FLASH_ATTN      (upstream)  |   1.77ms ± 0.00ms |  4631816 tok/s
  XFORMERS                    |   1.88ms ± 0.00ms |  4368585 tok/s
  TORCH_SDPA                  |   1.96ms ± 0.00ms |  4170571 tok/s


Failed runs: 6

Failure Details:
--------------------------------------------------
bs=1, seq_len=1024, heads=16, head_size=80, kv_heads=16, backend=FLASH_ATTN: This flash attention build does not support headdim not being a multiple of 32.
bs=1, seq_len=2048, heads=16, head_size=80, kv_heads=16, backend=FLASH_ATTN: This flash attention build does not support headdim not being a multiple of 32.
bs=1, seq_len=4096, heads=16, head_size=80, kv_heads=16, backend=FLASH_ATTN: This flash attention build does not support headdim not being a multiple of 32.
bs=2, seq_len=1024, heads=16, head_size=80, kv_heads=16, backend=FLASH_ATTN: This flash attention build does not support headdim not being a multiple of 32.
bs=2, seq_len=2048, heads=16, head_size=80, kv_heads=16, backend=FLASH_ATTN: This flash attention build does not support headdim not being a multiple of 32.
bs=2, seq_len=4096, heads=16, head_size=80, kv_heads=16, backend=FLASH_ATTN: This flash attention build does not support headdim not being a multiple of 32.

@Isotr0py (Member) commented Sep 8, 2025

batch_size=1, 2
seq_len=1024, 2048, 4096
num_heads=16
head_size=64, 80
num_kv_heads=16
backend=FLASH_ATTN, FLASH_ATTN (upstream), TORCH_SDPA, XFORMERS

I'd prefer to collect more e2e usage benchmark results for reference instead of pure kernel benchmark results, tbh (we all know FA is fastest at the kernel level), because what we care about here is how ViT's FA affects global performance. For example, we could test the ViTs of some quite large models like InternViT-6B.

The ViT is quite small compared to the text backbone in most cases (InternVL is an exception with its 6B ViT), so it is unlikely to be a performance bottleneck and the performance gain could be minor. But I'm still wondering whether a large ViT can benefit from upstream FA in the e2e case.

If upstream FA breaks anything, I think we could easiy disable it? And upstream FA would not take effect unless user manually install it.

Does this mean that we need to ask users to uninstall upstream FA to disable it? I'm afraid this could be an issue for downstream projects. Although FA is not a requirement for vLLM, it might be a requirement for downstream projects. 🤔

@wwl2755 (Contributor Author) commented Sep 8, 2025

ViT is quite small comparing to the text backbone in most of cases (InternVL is an exception with 6B ViT), which is unlikely to be a performance bottleneck, and the performance gain could be minor. But I'm still wondering if large ViT can benefit from upstream FA under e2e case.

I agree the bottleneck would be somewhere else (but I don't think that means we shouldn't address it? The improvement could show up once other bottlenecks are solved/reduced). I will do some e2e benchmarks with a larger ViT.

Does this mean that we need to ask user to uninstall upstream FA for disabling? I'm afraid that this could be an issue for downstream projects. Although FA is not a requirement for vLLM, it might have served as a requirement for downstream projects. 🤔

What I mean is that we could easily disable the use of upstream FA in vLLM (by setting is_upstream_fa_supported to False, commenting it out, and waiting for a fix). Then downstream users won't use upstream FA in vLLM even if their environment has it installed.

@wwl2755 (Contributor Author) commented Sep 8, 2025

I just checked OpenGVLab/InternViT-6B-224px, but I think the InternVisionModel architecture is not supported yet. And I cannot run e2e on OpenGVLab/InternVL3_5-38B, which has a 5.5B ViT, because I don't have enough space for such a large model.

However, I ran a profiler and we can see the performance increase in each step.

Plus, I also found several rearrange calls that could be combined in the flash_attn path, which could save about 20 microseconds (~10%) in each flash_attn call (see the illustrative sketch below).

[profiler trace screenshots]
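For illustration only (the concrete patterns in the flash_attn path are not reproduced here, and the qkv layout below is a made-up example), fusing two chained einops rearrange calls into one looks like this:

import torch
from einops import rearrange

# Assumed toy layout: (batch, seq, 3 * heads * head_dim) packed qkv.
bsz, seq, heads, dim = 2, 16, 16, 64
qkv = torch.randn(bsz, seq, 3 * heads * dim)

# Two chained rearranges ...
q, k, v = rearrange(qkv, "b s (t h d) -> t b s h d", t=3, h=heads)
q_flat = rearrange(q, "b s h d -> (b s) h d")

# ... fused into a single pattern, saving one extra reshape/copy per call.
q2, k2, v2 = rearrange(qkv, "b s (t h d) -> t (b s) h d", t=3, h=heads)
assert torch.equal(q_flat, q2)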

@wwl2755 (Contributor Author) commented Sep 9, 2025

Probably a 2-3% e2e improvement using vLLM's flash_attn:

vllm serve OpenGVLab/InternVL3_5-4B --trust-remote-code

vllm bench serve \
  --backend openai-chat \
  --endpoint-type openai-chat \
  --model OpenGVLab/InternVL3_5-4B \
  --endpoint /v1/chat/completions \
  --dataset-name hf \
  --dataset-path lmarena-ai/VisionArena-Chat \
  --hf-split train \
  --num-prompts 20 \
  --seed 40
  
# FLASH_ATTN
============ Serving Benchmark Result ============
Successful requests:                     20        
Benchmark duration (s):                  55.92     
Total input tokens:                      2896      
Total generated tokens:                  2189      
Request throughput (req/s):              0.36      
Output token throughput (tok/s):         39.14     
Total Token throughput (tok/s):          90.93     
---------------Time to First Token----------------
Mean TTFT (ms):                          52301.23  
Median TTFT (ms):                        52574.63  
P99 TTFT (ms):                           54208.64  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          18.13     
Median TPOT (ms):                        21.04     
P99 TPOT (ms):                           42.43     
---------------Inter-token Latency----------------
Mean ITL (ms):                           23.10     
Median ITL (ms):                         13.91     
P99 ITL (ms):                            195.22    
==================================================

# XFORMER
============ Serving Benchmark Result ============
Successful requests:                     20        
Benchmark duration (s):                  56.83     
Total input tokens:                      2896      
Total generated tokens:                  2186      
Request throughput (req/s):              0.35      
Output token throughput (tok/s):         38.47     
Total Token throughput (tok/s):          89.43     
---------------Time to First Token----------------
Mean TTFT (ms):                          53190.24  
Median TTFT (ms):                        53476.69  
P99 TTFT (ms):                           55115.90  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          18.18     
Median TPOT (ms):                        21.05     
P99 TPOT (ms):                           42.44     
---------------Inter-token Latency----------------
Mean ITL (ms):                           23.05     
Median ITL (ms):                         13.76     
P99 ITL (ms):                            193.57    
==================================================


mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
Labels
multi-modality Related to multi-modality (#4194) needs-rebase qwen Related to Qwen models rocm Related to AMD ROCm