
Conversation

Contributor

@Chenyaaang commented May 1, 2025

FIX #17145

This PR does two things:

  1. Adds a warning when the user doesn't pass --max-model-len. Without it, vLLM falls back to the default max length from the model's config JSON file, which is usually much larger than the actual use case and can cause OOM in later steps (such as initializing the KV cache). A rough sketch of this follows the list.
  2. Provides TPU-specific default max-num-batched-tokens values based on the device.
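Roughly, the warning in item 1 amounts to something like the sketch below (illustrative only; the argument and helper names are assumptions, not the exact code added in this PR):

```python
# Illustrative sketch of the max-model-len warning -- not the exact code in this PR.
import logging

logger = logging.getLogger(__name__)

def resolve_max_model_len(cli_max_model_len, config_max_model_len):
    """Fall back to the model config's max length, but warn, because that value
    is often much larger than the real use case and can cause OOM when the
    KV cache is initialized."""
    if cli_max_model_len is None:
        logger.warning(
            "--max-model-len was not set; falling back to the model config's "
            "maximum length (%d). This may be far larger than needed and can "
            "cause OOM when initializing the KV cache.", config_max_model_len)
        return config_max_model_len
    return cli_max_model_len
```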

The current default max-num-batched-tokens values are 2048 for online serving and 8192 for offline serving, which are too large for TPUs. We re-derive the batch size from TPU hardware specs, i.e. peak compute FLOPS and HBM bandwidth, and set max-num-batched-tokens above the peak arithmetic intensity so the matmuls become compute-bound.

For example, for the v6e chip, the peak compute (bf16) is 918 TFLOPS and the HBM bandwidth is 1640 GB/s, so the peak arithmetic intensity is about 560 FLOPs/byte. The arithmetic intensity of a matmul between fp16 matrices of shape (B, D) and (D, F) is 2BDF / (2(BD + DF + BF)), which is approximately B when D and F (hidden dimension, intermediate dimension) are much larger than B. So max-num-batched-tokens should be higher than 560. The same reasoning applies to v5e and v5p.

| Chip | TFLOPS (bf16) | HBM bandwidth (GB/s) | Peak AI | Online max-num-batched-tokens | Offline max-num-batched-tokens |
| --- | --- | --- | --- | --- | --- |
| v6e | 918 | 1640 | 560 | 1024 | 2048 |
| v5e | 197 | 819 | 240 | 512 | 1024 |
| v5p | 459 | 2765 | 166 | 256 | 512 |

P.S.: For quantized models, we assume both weights and activations are int8. The peak FLOPS for int8 is twice the bf16 peak, but each element is also half as many bytes, so the resulting batch-size threshold doesn't change.
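For reference, a minimal sketch (not part of the PR) that reproduces the peak-AI column of the table above from the quoted chip specs; the helper names and the D = F = 8192 example dimensions are assumptions for illustration:

```python
# Sketch: reproduce the peak-AI numbers above from the quoted hardware specs.
TPU_SPECS = {
    # chip: (peak bf16 TFLOPS, HBM bandwidth in GB/s)
    "v6e": (918, 1640),
    "v5e": (197, 819),
    "v5p": (459, 2765),
}

def peak_arithmetic_intensity(tflops: float, hbm_gbps: float) -> float:
    """FLOPs per byte at which the chip moves from memory-bound to compute-bound."""
    return (tflops * 1e12) / (hbm_gbps * 1e9)

def matmul_arithmetic_intensity(b: int, d: int, f: int, bytes_per_elem: int = 2) -> float:
    """AI of a (B, D) x (D, F) matmul: 2BDF FLOPs over the bytes of both inputs and the output."""
    flops = 2 * b * d * f
    bytes_moved = bytes_per_elem * (b * d + d * f + b * f)
    return flops / bytes_moved

for chip, (tflops, bw) in TPU_SPECS.items():
    print(chip, round(peak_arithmetic_intensity(tflops, bw)))
# v6e 560, v5e 241, v5p 166 -- matching the table up to rounding

# With D = F = 8192 (an assumed hidden/intermediate size), the matmul AI approaches B:
print(round(matmul_arithmetic_intensity(b=560, d=8192, f=8192)))  # ~493
```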


github-actions bot commented May 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀


mergify bot commented May 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Chenyaaang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label May 1, 2025
mergify bot removed the needs-rebase label May 1, 2025
Comment on lines 1467 to 1468
'v5e': 512,
'v5p': 256,
Member


Why is p smaller than e? It has more flops and bw.

Contributor Author


Yes, see the table above: the peak AI of v5p is smaller than that of v5e.

mergify bot added the tpu (Related to Google TPUs) label May 2, 2025
Comment on lines +1465 to +1469
UsageContext.OPENAI_API_SERVER: {
'V6E': 1024,
'V5E': 512,
'V5P': 256,
}
Member


I do still worry about these smaller sizes for multimodal models, since this is smaller than the usual image size in tokens, which will cause errors for the user. Maybe we can expand this for multimodal in a separate PR.

Contributor Author


Thanks. For multi-modal, is it required that we put all image tokens within one batch?

Member


Yes, we need to be able to put the largest single mm item in a single token batch. If we split the mm item, then we need slices that force recompilation on TPU.

Contributor Author


Got it, thanks for the explanation.
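For readers following along, a hypothetical sketch of the constraint being discussed (the function name and arguments are made up, not vLLM's actual API): the largest single multimodal item, counted in tokens, has to fit inside max-num-batched-tokens, otherwise it would be split across token batches and the variable-shaped slices would force recompilation on TPU.

```python
# Hypothetical illustration of the constraint discussed above -- not actual vLLM code.
def validate_mm_token_budget(max_num_batched_tokens: int,
                             largest_mm_item_tokens: int) -> None:
    """Reject configs where the largest single multimodal item cannot fit in one
    token batch, since splitting it would force recompilation on TPU."""
    if largest_mm_item_tokens > max_num_batched_tokens:
        raise ValueError(
            f"max-num-batched-tokens ({max_num_batched_tokens}) is smaller than the "
            f"largest multimodal item ({largest_mm_item_tokens} tokens); increase it "
            "so the item fits in a single batch.")
```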

mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) label May 2, 2025
vllm-bot merged commit 87baebe into vllm-project:main May 3, 2025
70 of 72 checks passed
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025