
Conversation

@BoyuanFeng (Contributor) commented Jul 24, 2025

Currently we rely on memory profiling to estimate the kv cache memory.

However, there are two issues:

  1. The estimation does not include CUDAGraph memory, since it happens before CUDAGraph capture. This underestimates memory consumption, so too much KV cache memory is allocated and the run OOMs later. For example, on Llama 4 Maverick the estimation misses ~1.3 GB of CUDAGraph memory, so it OOMs when gpu-memory-utilization=0.99.
  2. Some users may want to fully utilize GPU memory. Currently they have to trial-and-error with different gpu-memory-utilization values.

This PR adds a kv_cache_memory config so that users can specify the KV cache memory size in bytes.

By default, kv_cache_memory is None and we still rely on memory profiling to estimate the KV cache memory. In addition, memory profiling now suggests an optimal kv_cache_memory value that users can pass in future runs. For example, with vllm bench latency --model "google/gemma-3-4b-it", there would be a log like:

Free memory on device (177.66/178.36 GiB) on startup. Desired GPU memory utilization is (0.9, 160.53 GiB). Actual usage is 8.58 GiB for weight, 9.96 GiB for peak activation, 0.07 GiB for non-torch memory, and 0.65 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=151519866880` to fit into requested memory, or `--kv-cache-memory=169914314752` to fully utilize gpu memory. Current kv cache memory in use is 152379699200 bytes.

In future runs, users can pass the suggested value, e.g. vllm bench latency --model "google/gemma-3-4b-it" --kv-cache-memory=169914314752. This skips memory profiling and uses the user-specified KV cache memory size directly.
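For illustration, here is a back-of-the-envelope sketch of where the suggested numbers in the log above come from. This is not vLLM's actual code; the variable names are made up, and because the GiB figures in the log are rounded, the computed values only approximately reproduce the exact byte counts:

```python
# Back-of-the-envelope reconstruction of the log's suggestion (illustrative only;
# rounded GiB figures mean the results only roughly match the exact byte counts).
GiB = 1024**3

total_memory    = 178.36 * GiB   # total HBM reported on startup
free_on_startup = 177.66 * GiB   # free HBM reported on startup
gpu_mem_util    = 0.9            # requested --gpu-memory-utilization
weights         = 8.58 * GiB
peak_activation = 9.96 * GiB
non_torch       = 0.07 * GiB
cudagraph       = 0.65 * GiB

overhead = weights + peak_activation + non_torch + cudagraph

# "fit into requested memory": stay within total * gpu_memory_utilization
kv_within_budget = total_memory * gpu_mem_util - overhead   # ~151.7 GB

# "fully utilize gpu memory": spend everything that was free on startup
kv_full = free_on_startup - overhead                        # ~170.1 GB

print(f"--kv-cache-memory={int(kv_within_budget)}")  # log suggests 151519866880
print(f"--kv-cache-memory={int(kv_full)}")           # log suggests 169914314752
```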

Tested on: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8, meta-llama/Llama-4-Scout-17B-16E-Instruct,
Qwen/Qwen3-30B-A3B, and google/gemma-3-4b-it.

Closes: #19480

@gemini-code-assist bot left a comment


Code Review

This pull request improves the estimation of available KV Cache memory by accounting for CUDAGraph and non-torch memory, which helps prevent out-of-memory errors. The changes look good and correctly address the described issues. I've added a couple of suggestions to improve robustness and maintainability: one to prevent a potential division-by-zero error, and another to replace a magic number with a named constant for better clarity.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@ywang96 (Member) left a comment


I think it was originally designed on purpose that gpu_memory_utilization does not take CUDA graph memory into account, so that users can tune this parameter given their CUDA graph configuration (whether to turn it on or not, and how many graphs to capture).

This means vllm serve meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 --disable-log-requests -tp 8 --max-model-len 4096 --gpu-memory-utilization 0.99 --trust-remote-code is most definitely going to OOM by design, but not if the user adds --enforce-eager

While I very much agree that adding this sort of estimation should, generally speaking, improve UX, it is indeed a change in behavior of how our memory profiling works, so I'd like others to chime in too. cc @youkaichao @WoosukKwon

@NickLucche (Collaborator) commented Jul 25, 2025

Gemma's overestimation looks considerable

@houseroad (Collaborator) left a comment


As @ywang96 mentioned, shall we introduce a flag that disables the profiling, maps gpu_memory_utilization onto the full HBM, and lets power users just play with it?

The benefits would be 1) simpler startup logic with a smaller chance of failure, 2) faster startup, and 3) more options for power users to tweak.
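For illustration only, a minimal sketch of what such a profiling-free path could compute (hypothetical function name, not an actual vLLM API): the KV cache budget comes straight from total HBM, and the user leaves headroom for weights, activations, and CUDAGraph capture by tuning the fraction.

```python
# Hypothetical sketch of the proposal above, not vLLM's implementation:
# skip memory profiling and apply gpu_memory_utilization directly to total HBM.
import torch

def kv_cache_budget_no_profiling(gpu_memory_utilization: float, device: int = 0) -> int:
    """Return a KV cache budget in bytes as a fixed fraction of total device memory."""
    _free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    # The user tunes gpu_memory_utilization so that weights, activations, and
    # CUDAGraph memory still fit in the remaining (1 - fraction) of HBM.
    return int(total_bytes * gpu_memory_utilization)
```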

@yinghai (Contributor) commented Jul 26, 2025

I agree with @ywang96 that we are just adding another guesstimation on top of gpu_memory_utilization.

@mergify bot added the frontend label Aug 12, 2025
@BoyuanFeng marked this pull request as draft August 12, 2025 05:32

mergify bot commented Aug 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @BoyuanFeng.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Aug 13, 2025
@BoyuanFeng changed the title from "improve estimation of available KV Cache memory" to "Allow users to specify kv cache memory size" Aug 13, 2025
@BoyuanFeng marked this pull request as ready for review August 13, 2025 22:47
@mergify bot removed the needs-rebase label Aug 13, 2025
Signed-off-by: Boyuan Feng <[email protected]>
@BoyuanFeng force-pushed the bf/memory_utilization branch from 15dada4 to 9b16c00 on August 21, 2025 21:29
@mergify bot removed the tpu and needs-rebase labels Aug 21, 2025
necessary for implementing this optimization in some models (e.g. Gemma3n)
"""

kv_cache_memory: Optional[int] = None
Collaborator

Should this be named with a unit, something like kv_cache_memory_gb? Also, why not make it a float?

Collaborator

kv_cache_memory_bytes?

Collaborator

If it's bytes, then no need to use float.

Contributor Author

Yes, renamed to kv_cache_memory_bytes.
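For reference, a minimal sketch of how the renamed field could look in a config dataclass (illustrative only; everything except the kv_cache_memory_bytes name is an assumption, not vLLM's actual class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheConfig:  # illustrative stand-in, not vLLM's actual definition
    # Fraction of GPU memory to target when kv_cache_memory_bytes is not set.
    gpu_memory_utilization: float = 0.9
    # Explicit KV cache size in bytes. When set, memory profiling is skipped
    # and this value is used directly; bytes are integral, so no float needed.
    kv_cache_memory_bytes: Optional[int] = None
```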


mergify bot commented Aug 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @BoyuanFeng.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Aug 26, 2025
@mergify bot removed the needs-rebase label Aug 26, 2025
@hmellor (Member) left a comment


LGTM

@hmellor enabled auto-merge (squash) September 11, 2025 09:59
@github-actions bot added the ready label Sep 11, 2025
@hmellor merged commit 94e6b2d into vllm-project:main Sep 11, 2025
53 of 54 checks passed
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: Boyuan Feng <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Labels
ci/build, deepseek, documentation, frontend, gpt-oss, llama, multi-modality, new-model, performance, qwen, ready, rocm, speculative-decoding, structured-output, tool-calling, v1
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[Bug]: Compile inductor / CUDA Graph build before the memory profiling
7 participants