[Frontend][TPU] Add TPU default max-num-batched-tokens based on device name #17508
Conversation
Signed-off-by: Chenyaaang <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Chenyaaang <[email protected]>
vllm/engine/arg_utils.py
Outdated
'v5e': 512,
'v5p': 256,
Why is v5p smaller than v5e? It has more FLOPS and bandwidth.
Yes, see the table above: the peak arithmetic intensity (AI) of v5p is smaller than that of v5e.
Signed-off-by: Chenyaaang <[email protected]>
UsageContext.OPENAI_API_SERVER: {
    'V6E': 1024,
    'V5E': 512,
    'V5P': 256,
}
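For illustration, below is a minimal sketch of how such a device-keyed table can be consulted when the user did not set max-num-batched-tokens explicitly. The stand-in `UsageContext` enum, the `default_max_num_batched_tokens` helper, the device-name normalization, and the fallback value are assumptions made for the sketch, not the exact vLLM implementation.

```python
# Hypothetical sketch, not the exact vLLM code: pick a default
# max-num-batched-tokens from the TPU device name.
from enum import Enum


class UsageContext(Enum):
    # Stand-in for vLLM's UsageContext.
    OPENAI_API_SERVER = "openai_api_server"


# Online-serving defaults per TPU generation, as proposed in this PR.
DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU = {
    UsageContext.OPENAI_API_SERVER: {
        'V6E': 1024,
        'V5E': 512,
        'V5P': 256,
    },
}


def default_max_num_batched_tokens(usage_context: UsageContext,
                                   device_name: str,
                                   fallback: int = 2048) -> int:
    """Return the per-device default, falling back when the device is unknown."""
    per_device = DEFAULT_MAX_NUM_BATCHED_TOKENS_TPU.get(usage_context, {})
    # Normalize names such as "TPU v6e" so they match the table keys above.
    normalized = device_name.lower().replace(' ', '')
    for key, value in per_device.items():
        if key.lower() in normalized:
            return value
    return fallback


print(default_max_num_batched_tokens(UsageContext.OPENAI_API_SERVER, "TPU v6e"))  # 1024
```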
I do still worry about these smaller sizes for multimodal models, since they are smaller than the usual image size in tokens, which will cause errors for the user. Maybe we can expand this for multimodal in a separate PR.
Thanks. For multi-modal, is it required that we put all image tokens within one batch?
Yes, we need to be able to put the largest single multimodal item in a single token batch. If we split the multimodal item, then we need slices that force recompilation on TPU.
Got it, thanks for the explanation.
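For illustration, here is a minimal sketch of the kind of guard discussed above: checking that the token budget covers the largest single multimodal item. The function name and the `max_tokens_per_mm_item` value are hypothetical, not part of vLLM's API.

```python
# Hypothetical guard (names are illustrative): the whole largest multimodal
# item must fit into a single token batch, since splitting it would require
# variable-length slices that force recompilation on TPU.
def validate_mm_budget(max_num_batched_tokens: int,
                       max_tokens_per_mm_item: int) -> None:
    if max_num_batched_tokens < max_tokens_per_mm_item:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) is smaller "
            f"than the largest multimodal item ({max_tokens_per_mm_item} "
            "tokens); increase --max-num-batched-tokens.")


# Example: a 256-token default would reject a model whose largest image
# placeholder expands to 576 tokens.
validate_mm_budget(1024, 576)   # OK
# validate_mm_budget(256, 576)  # would raise ValueError
```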
…e name (vllm-project#17508) Signed-off-by: Chenyaaang <[email protected]> Signed-off-by: Mu Huai <[email protected]>
…e name (vllm-project#17508) Signed-off-by: Chenyaaang <[email protected]>
…e name (vllm-project#17508) Signed-off-by: Chenyaaang <[email protected]> Signed-off-by: Yuqi Zhang <[email protected]>
FIX #17145
This PR does the following two things:
1. Addresses the case where `--max-model-len` is not passed. Without it, vLLM falls back to the default max length from the model's config JSON file, which is usually much larger than the actual use case and might cause OOM in later steps (such as initializing the KV cache).
2. Sets TPU-specific defaults for max-num-batched-tokens based on the device name. The current defaults are 2048 for online serving and 8192 for offline serving, which are too large for TPUs. We re-calculate the batch size from the TPU hardware specs, i.e. peak compute FLOPS and HBM bandwidth, and set it to be higher than the threshold needed to reach peak arithmetic intensity.
For example, for the v6e chip the peak compute throughput (bf16) is 918 TFLOPS and the HBM bandwidth is 1640 GB/s, so the peak arithmetic intensity is about 560. The arithmetic intensity of a matmul of two matrices of shapes (B, D) and (D, F) in fp16 is

2BDF / (2(BD + DF + BF)) = BDF / (BD + DF + BF) ≈ B,

since D and F (the hidden and intermediate dimensions) are much larger than B. So max-num-batched-tokens should be higher than 560. The same reasoning applies to v5e and v5p.

P.S. If quantized (here we assume both weights and activations are int8), the peak FLOPS for int8 is twice the peak FLOPS for bf16, but the bytes per element are also halved, so both the hardware peak AI and the matmul AI double and the final threshold on B doesn't change.
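As a sanity check, the numbers above can be reproduced with a few lines of Python. The v6e figures are the spec values quoted in the description; the matmul shapes below are illustrative assumptions, not values from the PR.

```python
# Sanity check of the arithmetic-intensity argument above.
# Peak AI for TPU v6e, using the spec values quoted in the description.
peak_flops_bf16 = 918e12   # FLOP/s
hbm_bandwidth = 1640e9     # bytes/s
peak_ai = peak_flops_bf16 / hbm_bandwidth
print(f"v6e peak arithmetic intensity: {peak_ai:.0f} FLOP/byte")  # ~560


# AI of a (B, D) x (D, F) matmul in fp16 (2 bytes/element):
# 2*B*D*F FLOPs over 2*(B*D + D*F + B*F) bytes, which approaches B when D, F >> B.
def matmul_ai(B: int, D: int, F: int, bytes_per_elem: int = 2) -> float:
    flops = 2 * B * D * F
    bytes_moved = bytes_per_elem * (B * D + D * F + B * F)
    return flops / bytes_moved


# Illustrative shapes (hidden/intermediate dims much larger than B):
print(f"{matmul_ai(512, 8192, 28672):.0f}")   # ~474 FLOP/byte, below the ~560 peak
print(f"{matmul_ai(1024, 8192, 28672):.0f}")  # ~882 FLOP/byte, above the ~560 peak
```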