[Multi Modal] Add FA3 in VIT #24347
Conversation
Code Review
This pull request enables FlashAttention 3 in ViT, which is a great enhancement. The implementation correctly falls back to the upstream FlashAttention library when the head dimension is not a multiple of 32, which is necessary for certain models. My review includes a suggestion to optimize the import of the FlashAttention function to avoid performance overhead in the forward pass.
vllm/attention/layer.py (Outdated)
```python
if self.head_size % 32 != 0:
    # import from upstream flash_attn
    from flash_attn import flash_attn_varlen_func
else:
    from vllm.vllm_flash_attn import flash_attn_varlen_func

cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                            step=q_len,
                            dtype=torch.int32,
                            device=query.device)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                            step=kv_len,
                            dtype=torch.int32,
                            device=key.device)

out = flash_attn_varlen_func(
    query.flatten(0, 1),
    key.flatten(0, 1),
    value.flatten(0, 1),
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=q_len,
    max_seqlen_k=kv_len,
    softmax_scale=self.scale,
)
```
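For context on the snippet above: since every image in the batch contributes the same number of tokens here, the cumulative sequence lengths are just an arithmetic range. A small self-contained illustration (batch size and length picked arbitrarily):

```python
import torch

# With bsz=2 images and q_len=3 patch tokens each, the flattened
# (bsz * q_len, num_heads, head_size) tensor holds both sequences back to
# back, and cu_seqlens marks where each sequence starts and ends.
bsz, q_len = 2, 3
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                            step=q_len,
                            dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 3, 6], dtype=torch.int32)
```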
Placing imports inside a method on a hot path, like forward, can introduce significant performance overhead due to repeated module lookups. It's better to perform the import only once.

A good approach here would be to lazily initialize and cache the flash_attn_varlen_func on the first call. This avoids the import overhead on subsequent calls. While the ideal place for this logic is the __init__ method, lazy initialization within forward is a great improvement that keeps the changes within this method.
```python
if not hasattr(self, "_flash_attn_varlen_func"):
    if self.head_size % 32 != 0:
        # import from upstream flash_attn
        from flash_attn import flash_attn_varlen_func
    else:
        from vllm.vllm_flash_attn import flash_attn_varlen_func
    self._flash_attn_varlen_func = flash_attn_varlen_func
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                            step=q_len,
                            dtype=torch.int32,
                            device=query.device)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                            step=kv_len,
                            dtype=torch.int32,
                            device=key.device)
out = self._flash_attn_varlen_func(
    query.flatten(0, 1),
    key.flatten(0, 1),
    value.flatten(0, 1),
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=q_len,
    max_seqlen_k=kv_len,
    softmax_scale=self.scale,
)
```
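As the comment notes, the ideal home for this is __init__; a minimal sketch of that variant, assuming self.head_size is already known at construction time (illustrative, not the committed change):

```python
# In __init__ (sketch): resolve the varlen kernel once at construction time,
# so forward() has no import or hasattr check on the hot path.
if self.head_size % 32 != 0:
    # vLLM's FA fork rejects head sizes that aren't multiples of 32,
    # so fall back to the upstream flash_attn package.
    from flash_attn import flash_attn_varlen_func
else:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
```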
vllm/attention/layer.py (Outdated)
```python
if self.head_size % 32 != 0:
    # import from upstream flash_attn
    from flash_attn import flash_attn_varlen_func
```
Hmmm, I think get_attn_backend won't and shouldn't return the FA backend if head_size is incompatible. Can you check if this code path is actually active?
You're right. The code path doesn't work in this commit. Thanks for pointing it out!
In my local commit, I tested by hard-coding supported_head_sizes() in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flash_attn.py#L50 and it works fine.
In this case, I think we have a few options to handle the cases where head_size % 32 != 0:
1. Still use xformers. -> We just delete this part and leave it for a future refactor, or until vLLM's native fork changes its settings.
2. Change supported_head_sizes() to multiples of 8. -> This change may affect some other places; not sure if it is good.
   2.1 Refactor MHA to use get_vit_attn_backend. -> vit_attn has fewer choices, so it should be easier to handle than option 2.

WDYT? And for those with head_size % 32 == 0, I think it should be fine as long as we feel good with the precision.
Wouldn't this be an issue since upstream flash_attn is not a dependency of vLLM?
If so, I think it's probably a better idea to check this at model init time within get_vit_attn_backend. I'm thinking:
- If a head size is supported by vLLM FA, use vLLM FA.
- If a head size is not supported by vLLM FA but upstream FA is detected, use upstream FA.
- If a head size is not supported by vLLM FA and upstream FA is not installed, we simply fall back to xformers, but print something like "Using xformers for ViT attention backend. To use flash attention for ViT please install flash_attn".

get_vit_attn_backend should be responsible for making the decision here and returning the true attention backend to use. We can simply return another boolean from get_vit_attn_backend to differentiate upstream FA and vLLM FA if that's the concern.
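A rough, self-contained sketch of the decision flow described above (the enum and helper names here are stand-ins for illustration, not vLLM's actual API):

```python
import importlib.util
import logging
from enum import Enum, auto

logger = logging.getLogger(__name__)


class Backend(Enum):  # stand-in for vLLM's _Backend enum
    FLASH_ATTN = auto()
    XFORMERS = auto()


def has_upstream_flash_attn() -> bool:
    # Detect the upstream flash_attn package without importing it.
    return importlib.util.find_spec("flash_attn") is not None


def pick_vit_attn_backend(head_size: int) -> tuple[Backend, bool]:
    # Placeholder support check; later in the thread the plain modulus
    # test is replaced with is_attn_backend_supported().
    vllm_fa_supported = head_size % 32 == 0

    if vllm_fa_supported:
        return Backend.FLASH_ATTN, False   # use vLLM's vllm-flash-attn fork
    if has_upstream_flash_attn():
        return Backend.FLASH_ATTN, True    # fall back to upstream flash_attn
    logger.warning("Using xformers for ViT attention backend. "
                   "To use flash attention for ViT, please install flash_attn")
    return Backend.XFORMERS, False


print(pick_vit_attn_backend(72))  # (Backend.FLASH_ATTN, True) if flash_attn is installed
```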
> get_vit_attn_backend should be responsible for making the decision here

Totally agree with the fallback algorithm. It's pretty much what I planned as option 2.1. The current concern is that MHA attention is still using get_attn_backend. I'm not sure whether there are any other scenarios/corner cases (besides ViT) that would trigger this code path. If there aren't any, I can refactor get_vit_attn_backend and use it in MHA instead.
After scanning the files that reference MHA, they all seem to be ViT-related. So I think it should be fine to switch to get_vit_attn_backend there.
I will update the PR soon.
Here are detailed lm_eval results:
- OpenGVLab/InternVL3_5-4B (head_dim is 64)
  - _Backend.FLASH_ATTN_VLLM_V1
  - _Backend.XFORMERS
- llava-hf/llava-onevision-qwen2-7b-ov-hf (head_dim is 72)
  - _Backend.FLASH_ATTN_VLLM_V1 (from upstream flash_attn)
  - _Backend.XFORMERS
@ywang96 @Isotr0py @DarkLight1337 Gently pinging for a quick review if you have some time, to see whether this refactor makes sense to you. I will work on changing the function interfaces in the related tests at the same time.
Can you add a benchmark comparison?
@david6666666 Here is the benchmark.
vllm/platforms/cuda.py (Outdated)
```python
if cls.has_device_capability(80):
    if head_size % 32 == 0:
        # Use vllm-flash-attn
        return _Backend.FLASH_ATTN, False
    if head_size % 32 != 0:
        from transformers.utils import is_flash_attn_2_available
        if is_flash_attn_2_available():
            # Use upstream FA
            return _Backend.FLASH_ATTN, True
        else:
            # Fallback to XFORMERS
            logger.warning_once(
                "Using xformers for ViT attention backend. "
                "To use flash attention for ViT, "
                "please install flash_attn")
            return _Backend.XFORMERS, False
```
Suggested change:

```python
if cls.has_device_capability(80):
    is_default_fa_supported = is_attn_backend_supported(
        FLASH_ATTN_V1, head_size, dtype,
        allow_import_error=False)
    is_upstream_fa_supported = is_flash_attn_2_available()
    if is_default_fa_supported:
        return _Backend.FLASH_ATTN, False
    elif is_upstream_fa_supported:
        return _Backend.FLASH_ATTN, True
return _Backend.XFORMERS, False
```
head_size % 32 == 0 may not be a robust solution, because the fork FA's supported head sizes might change in the future. I think we can use is_attn_backend_supported here.
Updated.
Hmmm, it seems the benefit from FA is still quite minor for ViT (this is the same as my benchmark conclusion last time). I'm hesitant about whether it's worthwhile to introduce the extra maintenance burden from upstream FA... 🤔
Yes, the improvement seems subtle. I'm testing some larger images to see if that makes a difference. However, I think we should still do this refactor, because the previous MHA was using the get_attn_backend() code path. After refactoring get_vit_attn_backend, it should be easier to consolidate the individual models like Qwen2-VL.
If upstream FA breaks anything, I think we could easily disable it?
Hmmm, I think the exact blocker for Qwen2-VL-style ViT consolidation is how to unify
Here are more test results. I'm using upstream as
What we compared:
I prefer to collect more e2e usage benchmark results for reference instead of pure kernel benchmark results, tbh (we all know FA is fastest at the kernel level), because what we care about here is how ViT's FA affects global performance. For example, we could test some quite large ViTs like InternViT-6B. The ViT is quite small compared to the text backbone in most cases (InternVL is an exception with a 6B ViT), so it is unlikely to be a performance bottleneck, and the performance gain could be minor. But I'm still wondering if a large ViT can benefit from upstream FA in the e2e case.
Does this mean that we need to ask users to uninstall upstream FA to disable it? I'm afraid that this could be an issue for downstream projects. Although FA is not a requirement for vLLM, it might have served as a requirement for downstream projects. 🤔
I agree the bottleneck would be somewhere else (but I don't think that means we should not address it? The improvement could show up once other bottlenecks are solved/reduced.) I will do some e2e benchmarks with a larger ViT.
Uh, what I meant is that we could easily disable the usage of upstream FA in vLLM (by setting
Probably a 2-3% e2e improvement using vLLM's flash_attn.
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
Enable FA in ViT, including for models whose head_dim is not a multiple of 32 (e.g., 72). vLLM's native FA does not support such head sizes because it is built with FLASHATTENTION_DISABLE_UNEVEN_K, so uneven head sizes cannot be padded (http://github.com/vllm-project/flash-attention/blob/v2.6.2/setup.py#L223). So, we use vLLM's native flash attention when applicable and upstream FA as a fallback.
Previous attempt and revert: #12435, #12355, #12445
Refactor get_vit_attn_backend() so that it can decide the attention backend type internally (a rough usage sketch is shown below). Related to #23880, #23884
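A rough sketch of how a ViT attention module could consume the refactored helper, based on the cuda.py diff above; the import paths and exact signature are assumptions here, not necessarily the final API:

```python
import torch

# Assumed interface from the discussion in this PR: the platform helper
# returns (backend, use_upstream_fa). Import paths and the signature are
# illustrative and may differ from the final code.
from vllm.platforms import current_platform
from vllm.platforms.interface import _Backend  # import path assumed

backend, use_upstream_fa = current_platform.get_vit_attn_backend(
    head_size=72, dtype=torch.bfloat16)

if backend == _Backend.FLASH_ATTN:
    if use_upstream_fa:
        # head_size unsupported by vLLM's FA fork: use upstream flash_attn.
        from flash_attn import flash_attn_varlen_func
    else:
        from vllm.vllm_flash_attn import flash_attn_varlen_func
```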
cc: @ywang96 @DarkLight1337 @Isotr0py
Test
- Tested with llava-hf/llava-onevision-qwen2-7b-ov-hf, whose head_dim is 72, not a multiple of 32.
- Tested with OpenGVLab/InternVL3_5-4B, whose head_dim is 64, a multiple of 32.
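As a quick way to see which path a given model's ViT takes, one can compute its head_dim from the vision config; a small sketch (config field names vary by model, so treat this as illustrative):

```python
from transformers import AutoConfig

# head_dim of the vision tower decides the code path:
# multiple of 32 -> vLLM's FA fork, otherwise upstream FA (or xformers).
cfg = AutoConfig.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
vision_cfg = cfg.vision_config
head_dim = vision_cfg.hidden_size // vision_cfg.num_attention_heads
print(head_dim, head_dim % 32 == 0)  # 72 False -> needs the upstream-FA fallback
```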