@@ -843,6 +843,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
843
843
kv_cache_tensors = kv_cache_tensors ,
844
844
kv_cache_groups = create_kv_cache_group_specs (kv_cache_spec ,
845
845
grouped_layer_names ),
846
+ kv_bytes_per_block = len (kv_cache_tensors ) * page_size ,
846
847
)
847
848
848
849
num_tokens = num_blocks * vllm_config .cache_config .block_size
@@ -1003,6 +1004,7 @@ def _get_kv_cache_config_uniform_page_size(
1003
1004
num_blocks = num_blocks ,
1004
1005
kv_cache_tensors = kv_cache_tensors ,
1005
1006
kv_cache_groups = kv_cache_groups ,
1007
+ kv_bytes_per_block = len (kv_cache_tensors ) * page_size ,
1006
1008
)
1007
1009
1008
1010
min_block_size = min (
@@ -1021,7 +1023,10 @@ def _get_kv_cache_config_uniform_page_size(
1021
1023
1022
1024
1023
1025
def _get_kv_cache_config_attention_free() -> KVCacheConfig:
    """Return the fixed ``KVCacheConfig`` for the attention-free case.

    The config carries a single block, no KV cache tensors and no KV cache
    groups. ``kv_bytes_per_block`` is passed explicitly as 0 so this
    constructor call supplies the same keyword as the other
    ``KVCacheConfig(...)`` call sites in this module.

    Returns:
        A ``KVCacheConfig`` with ``num_blocks=1``, empty
        ``kv_cache_tensors`` / ``kv_cache_groups`` and
        ``kv_bytes_per_block=0``.
    """
    return KVCacheConfig(
        num_blocks=1,
        kv_cache_tensors=[],
        kv_cache_groups=[],
        # No tensors are listed, so zero bytes are reported per block.
        kv_bytes_per_block=0,
    )
1025
1030
1026
1031
1027
1032
def unify_hybrid_kv_cache_specs (kv_cache_spec : dict [str , KVCacheSpec ]):
@@ -1149,7 +1154,12 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
1149
1154
# first `num_blocks` blocks of the tensor.
1150
1155
min_num_blocks = min (kv_cache_config .num_blocks
1151
1156
for kv_cache_config in kv_cache_configs )
1157
+ kv_bytes_per_block = sum ([
1158
+ kv_cache_config .kv_bytes_per_block
1159
+ for kv_cache_config in kv_cache_configs
1160
+ ])
1152
1161
for kv_cache_config in kv_cache_configs :
1153
1162
kv_cache_config .num_blocks = min_num_blocks
1163
+ kv_cache_config .kv_bytes_per_block = kv_bytes_per_block
1154
1164
1155
1165
return kv_cache_configs
0 commit comments