51 changes: 51 additions & 0 deletions tests/v1/core/test_prefix_caching.py
@@ -1362,3 +1362,54 @@ def test_eagle_with_sliding_window():
# there will be no matched prefix.
assert len(computed_blocks.blocks[0]) == 0
assert num_tokens == 0


def test_delay_free_blocks():
"""Test delay free blocks"""
block_pool = BlockPool(
num_gpu_blocks=7,
delay_batch_size=16,
enable_caching=True,
)

assert block_pool.check_blocks_enough(6)
# blocks:
# list[0]: block1
    # list[1]: block2 block3
    # list[2]: block4 block5 block6
blocks: list[list[KVCacheBlock]] = []
blocks.append(block_pool.get_new_blocks(1))
blocks.append(block_pool.get_new_blocks(2))
blocks.append(block_pool.get_new_blocks(3))

for item in blocks:
ordered_item = reversed(item)
block_pool.free_blocks(ordered_item)
assert len(block_pool.delay_free_blocks) == 3
assert len(block_pool.delay_free_blocks[2]) == 3

assert block_pool.check_blocks_enough(6)

block_ids = [6, 3, 5, 1, 2, 4]
cur_blocks = block_pool.get_new_blocks(6)
for i in range(len(cur_blocks)):
assert block_ids[i] == cur_blocks[i].block_id

block_pool = BlockPool(
num_gpu_blocks=7,
delay_batch_size=4,
enable_caching=True,
)

blocks.clear()
blocks.append(block_pool.get_new_blocks(1))
blocks.append(block_pool.get_new_blocks(1))
blocks.append(block_pool.get_new_blocks(1))
blocks.append(block_pool.get_new_blocks(1))

for item in blocks:
ordered_item = reversed(item)
block_pool.free_blocks(ordered_item)

assert len(block_pool.delay_free_blocks) == 0
assert block_pool.get_num_free_blocks() == 6
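For reference, the expected allocation order [6, 3, 5, 1, 2, 4] comes from the column-wise flush in _free_delay_blocks (shown later in this diff): the buffered per-request lists are aligned at their tails and flushed from the farthest column inward, preserving each request's eviction-priority order while interleaving requests. A minimal standalone sketch of that ordering, using plain ints in place of KVCacheBlock:

def flush_order(buffered: list[list[int]]) -> list[int]:
    """Reproduce the release order of the delayed flush: columns are aligned
    at the tails of the buffered lists and walked from the farthest column
    inward."""
    order: list[int] = []
    max_len = max(len(req) for req in buffered)
    for col in range(max_len - 1, -1, -1):
        for req in buffered:
            if col < len(req):
                order.append(req[len(req) - col - 1])
    return order

# The buffered lists from the test above (each request already reversed on free):
assert flush_order([[1], [3, 2], [6, 5, 4]]) == [6, 3, 5, 1, 2, 4]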
2 changes: 2 additions & 0 deletions vllm/config/cache.py
@@ -117,6 +117,8 @@ class CacheConfig:
attention metadata for eligible layers to be overridden with metadata
necessary for implementing this optimization in some models (e.g. Gemma3n)
"""
    delay_batch_size: int = 16
    """Number of block-free requests the block pool buffers before returning
    the freed blocks to the free queue in one batch."""

def compute_hash(self) -> str:
"""
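As the later hunks in kv_cache_utils.py and kv_cache_coordinator.py show, this value is copied from CacheConfig onto KVCacheConfig and then handed to BlockPool. A condensed, self-contained sketch of that path, using simplified stand-in dataclasses (hypothetical; only the delay_batch_size field mirrors this diff):

from dataclasses import dataclass

@dataclass
class StubCacheConfig:          # stand-in for vllm.config.cache.CacheConfig
    delay_batch_size: int = 16

@dataclass
class StubKVCacheConfig:        # stand-in for vllm.v1.kv_cache_interface.KVCacheConfig
    num_blocks: int
    delay_batch_size: int = 16

cache_config = StubCacheConfig()
kv_cache_config = StubKVCacheConfig(num_blocks=128)
# get_kv_cache_config() copies the value across (kv_cache_utils.py below) ...
kv_cache_config.delay_batch_size = cache_config.delay_batch_size
# ... and KVCacheCoordinator forwards it as BlockPool's second argument
# (kv_cache_coordinator.py below).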
57 changes: 45 additions & 12 deletions vllm/v1/core/block_pool.py
@@ -32,11 +32,13 @@ class BlockPool:
def __init__(
self,
num_gpu_blocks: int,
delay_batch_size: int,
enable_caching: bool,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.delay_batch_size = delay_batch_size
self.enable_caching = enable_caching
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
@@ -46,6 +48,7 @@ def __init__(
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
self.delay_free_blocks: list[list[KVCacheBlock]] = []

# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
@@ -172,8 +175,10 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
            A list of new blocks.
"""
if num_blocks > self.get_num_free_blocks():
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
self._free_delay_blocks()
if num_blocks > self.get_num_free_blocks():
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")

ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)

@@ -240,22 +245,37 @@ def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
self.free_block_queue.remove(block)
block.ref_cnt += 1

def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
def free_blocks(self,
ordered_blocks: Iterable[KVCacheBlock],
refresh: bool = False) -> None:
"""Free a list of blocks. The blocks should be ordered by their
eviction priority, where the first block will be evicted first.

Args:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
            refresh: If True, flush the delay-free buffer immediately instead
                of waiting for delay_batch_size batches to accumulate.
                Defaults to False.
"""
# Materialize the iterable to allow multiple passes.
blocks_list = list(ordered_blocks)
for block in blocks_list:
block.ref_cnt -= 1
self.free_block_queue.append_n([
block for block in blocks_list
if block.ref_cnt == 0 and not block.is_null
])
        # Buffer this request's freed blocks; they are returned to the free
        # queue once delay_batch_size batches have accumulated or on refresh.
        self.delay_free_blocks.append(list(ordered_blocks))
        if len(self.delay_free_blocks) >= self.delay_batch_size or refresh:
            self._free_delay_blocks()

    def _free_delay_blocks(self):
        """Flush the delay-free buffer.

        The buffered per-request lists are aligned at their tails and walked
        column by column from the farthest column inward, so each request's
        eviction-priority order is preserved while requests are interleaved.
        Blocks whose reference count drops to zero are appended to the free
        block queue.
        """
        if not self.delay_free_blocks:
            return
        max_len = max(len(req) for req in self.delay_free_blocks)
        for col in range(max_len - 1, -1, -1):
            for row in range(len(self.delay_free_blocks)):
                blocks = self.delay_free_blocks[row]
                blocks_len = len(blocks)
                if col < blocks_len:
                    block = blocks[blocks_len - col - 1]
                    block.ref_cnt -= 1
                    if block.ref_cnt == 0 and not block.is_null:
                        self.free_block_queue.append(block)

        self.delay_free_blocks.clear()

def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
@@ -266,7 +286,7 @@ def reset_prefix_cache(self) -> bool:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks(True)
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "
@@ -287,14 +307,27 @@ def reset_prefix_cache(self) -> bool:

return True

def get_num_free_blocks(self) -> int:
def get_num_free_blocks(self, refresh: bool = False) -> int:
"""Get the number of free blocks in the pool.

Returns:
The number of free blocks.
"""
if refresh:
self._free_delay_blocks()
return self.free_block_queue.num_free_blocks

    def check_blocks_enough(self, num_blocks: int) -> bool:
        """Check whether enough free blocks are available for an allocation.

        If the free queue alone cannot satisfy the request, the delay-free
        buffer is flushed and the check is repeated.

        Returns:
            bool: True if there are enough blocks, False otherwise.
        """
        if self.get_num_free_blocks() >= num_blocks:
            return True
        return self.get_num_free_blocks(refresh=True) >= num_blocks

def get_usage(self) -> float:
"""Get the KV cache usage.

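A minimal usage sketch of the delayed-free behaviour, assuming the constructor shown above and the test's observation that one of the seven blocks is reserved as the null block, leaving six allocatable:

from vllm.v1.core.block_pool import BlockPool

pool = BlockPool(num_gpu_blocks=7, delay_batch_size=2, enable_caching=True)

req_a = pool.get_new_blocks(2)
req_b = pool.get_new_blocks(2)

pool.free_blocks(reversed(req_a))          # buffered, not yet back in the free queue
assert len(pool.delay_free_blocks) == 1

pool.free_blocks(reversed(req_b))          # second batch reaches delay_batch_size,
assert len(pool.delay_free_blocks) == 0    # so both buffered batches are flushed here

# check_blocks_enough() would also flush the buffer itself if the free queue
# alone could not satisfy a request; here all six blocks are already free.
assert pool.check_blocks_enough(6)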
5 changes: 3 additions & 2 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -30,8 +30,9 @@ def __init__(
self.max_model_len = max_model_len
self.enable_caching = enable_caching

self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
self.block_pool = BlockPool(kv_cache_config.num_blocks,
kv_cache_config.delay_batch_size,
enable_caching, enable_kv_cache_events)

# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
2 changes: 1 addition & 1 deletion vllm/v1/core/kv_cache_manager.py
@@ -267,7 +267,7 @@ def allocate_slots(
num_encoder_tokens=num_encoder_tokens,
)

if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
if not self.block_pool.check_blocks_enough(num_blocks_to_allocate):
# Cannot allocate new blocks
return None

17 changes: 11 additions & 6 deletions vllm/v1/core/kv_cache_utils.py
@@ -1100,27 +1100,32 @@
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
kvCacheConfig: KVCacheConfig = None

Check failure on line 1103 in vllm/v1/core/kv_cache_utils.py (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "None", variable has type "KVCacheConfig") [assignment]

if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)

if is_kv_cache_type_attention_free(kv_cache_spec):
# This returns a kv_cache config with 0 kv_cache groups and 1 block
# to allow for the KVCache manager to handle attention free models.
return _get_kv_cache_config_attention_free()
kvCacheConfig = _get_kv_cache_config_attention_free()
elif is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
kvCacheConfig = _get_kv_cache_config_uniform_type(
vllm_config, kv_cache_spec, available_memory)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_config_uniform_page_size(vllm_config,
kv_cache_spec,
available_memory)
kvCacheConfig = _get_kv_cache_config_uniform_page_size(
vllm_config, kv_cache_spec, available_memory)
if kvCacheConfig:
if vllm_config.cache_config:
delay_batch_size = vllm_config.cache_config.delay_batch_size
kvCacheConfig.delay_batch_size = delay_batch_size
return kvCacheConfig

raise NotImplementedError

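The pre-commit failure recorded above comes from annotating the new variable as KVCacheConfig while initializing it to None. One way to satisfy the type checker, sketched here as a suggestion rather than taken from the PR (names mirror the surrounding function), is an optional annotation plus an explicit None check before the final return:

from typing import Optional

kv_cache_config: Optional[KVCacheConfig] = None
# ... one of the branches above assigns kv_cache_config ...
if kv_cache_config is not None:
    if vllm_config.cache_config:
        kv_cache_config.delay_batch_size = (
            vllm_config.cache_config.delay_batch_size)
    return kv_cache_config

raise NotImplementedError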
4 changes: 4 additions & 0 deletions vllm/v1/kv_cache_interface.py
@@ -272,3 +272,7 @@ class KVCacheConfig:
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
kv_cache_groups: list[KVCacheGroupSpec]
"""
The batch size of the delayed release
"""
delay_batch_size: int = 16