From 0b6d358532cde6b9cafb65e11d2d655b56b8a7d3 Mon Sep 17 00:00:00 2001
From: Or Ozeri
Date: Tue, 9 Sep 2025 12:47:15 +0300
Subject: [PATCH] v1: Set num_cpu_blocks on VllmConfig

This commit sets the vllm_config.cache_config.num_cpu_blocks
according to vllm_config.cache_config.swap_space.

Signed-off-by: Or Ozeri
---
 vllm/v1/core/kv_cache_utils.py | 12 +++++++++++-
 vllm/v1/engine/core.py         |  6 +++++-
 vllm/v1/kv_cache_interface.py  |  2 ++
 3 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 248ad9cda7c2..dba17a1aa1a2 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -843,6 +843,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         kv_cache_tensors=kv_cache_tensors,
         kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                     grouped_layer_names),
+        kv_bytes_per_block=len(kv_cache_tensors) * page_size,
     )
 
     num_tokens = num_blocks * vllm_config.cache_config.block_size
@@ -1003,6 +1004,7 @@ def _get_kv_cache_config_uniform_page_size(
         num_blocks=num_blocks,
         kv_cache_tensors=kv_cache_tensors,
         kv_cache_groups=kv_cache_groups,
+        kv_bytes_per_block=len(kv_cache_tensors) * page_size,
     )
 
     min_block_size = min(
@@ -1021,7 +1023,10 @@ def _get_kv_cache_config_uniform_page_size(
 
 
 def _get_kv_cache_config_attention_free() -> KVCacheConfig:
-    return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])
+    return KVCacheConfig(num_blocks=1,
+                         kv_cache_tensors=[],
+                         kv_cache_groups=[],
+                         kv_bytes_per_block=0)
 
 
 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
@@ -1149,7 +1154,12 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
     # first `num_blocks` blocks of the tensor.
     min_num_blocks = min(kv_cache_config.num_blocks
                          for kv_cache_config in kv_cache_configs)
+    kv_bytes_per_block = sum([
+        kv_cache_config.kv_bytes_per_block
+        for kv_cache_config in kv_cache_configs
+    ])
     for kv_cache_config in kv_cache_configs:
         kv_cache_config.num_blocks = min_num_blocks
+        kv_cache_config.kv_bytes_per_block = kv_bytes_per_block
 
     return kv_cache_configs
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 922c06b44be8..6d82e50fd04f 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -207,7 +207,11 @@ def _initialize_kv_caches(
             for cfg in kv_cache_configs
         ])
         num_gpu_blocks = kv_cache_configs[0].num_blocks
-        num_cpu_blocks = 0
+        if kv_cache_configs[0].kv_bytes_per_block == 0:
+            num_cpu_blocks = 0
+        else:
+            num_cpu_blocks = (int(vllm_config.cache_config.swap_space_bytes) //
+                              kv_cache_configs[0].kv_bytes_per_block)
         scheduler_kv_cache_config = kv_cache_configs[0]
 
         # Initialize kv cache and warmup the execution
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index a3e4d393e4d2..e87267e7915a 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -264,3 +264,5 @@ class KVCacheConfig:
     see `_get_kv_cache_config_uniform_page_size` for more details.
     """
     kv_cache_groups: list[KVCacheGroupSpec]
+    """The number of KV bytes per block"""
+    kv_bytes_per_block: int