
Commit fdcf64d

[V1] Clarify input processing and multimodal feature caching logic (#13211)
1 parent 578087e commit fdcf64d

4 files changed (+45, -27 lines)

vllm/v1/engine/core.py

Lines changed: 8 additions & 8 deletions
@@ -20,7 +20,7 @@
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
-from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
+from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -65,7 +65,7 @@ def __init__(
             log_stats=self.log_stats,
         )
 
-        self.mm_input_mapper_server = MMInputMapperServer(
+        self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)
 
     def _initialize_kv_caches(self,
@@ -102,13 +102,13 @@ def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
 
         if request.mm_hashes is not None:
-            # Here, if hash exists for an image, then it will be fetched
-            # from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client side of the
-            # MM mapper, so anything that has a hash must have a HIT cache
-            # entry here as well.
+            # Here, if hash exists for a multimodal input, then it will be
+            # fetched from the cache, else it will be added to the cache.
+            # Note that the cache here is mirrored with the client cache, so
+            # anything that has a hash must have a HIT cache entry here
+            # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                 request.mm_inputs, request.mm_hashes)
 
         req = Request.from_engine_core_request(request)
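The renamed get_and_update call above relies on the mirrored-cache contract: entries that the frontend client has already cached arrive as None and are filled in from the server's own cache, while new entries are stored for later requests. The following is a minimal, self-contained sketch of that contract only; the SketchMMCacheServer name, the plain-dict payloads, and the None-for-hit convention are illustrative assumptions, not vLLM's actual implementation.

from collections import OrderedDict
from typing import List, Optional


class SketchMMCacheServer:
    """Toy stand-in for the engine-core side of the mirrored cache."""

    def __init__(self, capacity: int = 256):
        self._capacity = capacity
        self._cache = OrderedDict()  # mm_hash -> multimodal kwargs stand-in

    def get_and_update(
        self,
        mm_inputs: List[Optional[dict]],
        mm_hashes: List[str],
    ) -> List[dict]:
        full_inputs = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                # The client hit its own cache, so the mirrored entry must
                # already be present here.
                mm_input = self._cache[mm_hash]
            else:
                self._cache[mm_hash] = mm_input
                if len(self._cache) > self._capacity:
                    self._cache.popitem(last=False)  # evict the LRU entry
            self._cache.move_to_end(mm_hash)  # mark as most recently used
            full_inputs.append(mm_input)
        return full_inputs

In the real code the payloads are MultiModalKwargs objects and the cache is vLLM's LRUCache; the sketch only illustrates the fill-or-store control flow described in the comment above.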

vllm/v1/engine/mm_input_mapper.py renamed to vllm/v1/engine/mm_input_cache.py

Lines changed: 19 additions & 10 deletions
@@ -10,12 +10,18 @@
 
 logger = init_logger(__name__)
 
-# The idea of MM preprocessor caching is based on having a client and a server,
-# where the client executes in the frontend process (=P0) and the server in the
-# core process (=P1).
+# The idea of multimodal preprocessing caching is based on having a client and
+# a server, where the client executes in the frontend process (=P0) and the
+# server in the core process (=P1).
 #
-# -- Client: Executes the MM mapper and performs caching of the results.
-# -- Server: Performs caching of the results
+# -- Client:
+#  - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
+#  - Perform caching of the generated MultiModalKwargs.
+#  - This client can be deprecated once all multimodal models migrate to use
+#    merged preprocessor with built-in caching functionality.
+#
+# -- Server:
+#  - Perform caching of the received MultiModalKwargs.
 #
 # The caching for both client and server is mirrored/similar, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
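The mirroring described in this comment works because both processes key their caches on a content hash of the multimodal input, so only the short hash (plus the full kwargs on a miss) has to cross the P0-to-P1 boundary. The helper below is a rough illustration of such a keying scheme; vLLM has its own hasher, and the blake2b digest, function name, and hashed fields here are assumptions made purely for illustration.

import hashlib

import numpy as np


def sketch_mm_hash(pixel_values: np.ndarray, mm_processor_kwargs: dict) -> str:
    """Derive a stable cache key for one multimodal item (illustrative only)."""
    h = hashlib.blake2b(digest_size=16)
    h.update(np.ascontiguousarray(pixel_values).tobytes())
    # Include processor kwargs so different preprocessing settings do not
    # collide on the same cache entry.
    h.update(repr(sorted(mm_processor_kwargs.items())).encode())
    return h.hexdigest()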
@@ -27,7 +33,9 @@
 MM_CACHE_SIZE = 256
 
 
-class MMInputMapperClient:
+# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
+# merged preprocessor with built-in caching functionality.
+class MMInputCacheClient:
 
     def __init__(
         self,
@@ -54,7 +62,8 @@ def cache_hit_ratio(self, steps):
         logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                      self.mm_cache_hits / self.mm_cache_total)
 
-    # TODO: Support modalities beyond image.
+    # NOTE: process_inputs only supports image inputs since all multimodal
+    # models with other modalities have migrated to use merged preprocessor.
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
@@ -95,7 +104,7 @@ def process_inputs(
                 # Reuse precomputed input (for merged preprocessor)
                 mm_input = precomputed_mm_inputs[input_id]
             else:
-                # Apply MM mapper
+                # Apply legacy input_mapper
                 mm_input = self.multi_modal_input_mapper(
                     {"image": [image_inputs[input_id]]},
                     mm_processor_kwargs=mm_processor_kwargs,
@@ -114,13 +123,13 @@ def process_inputs(
         return ret_inputs
 
 
-class MMInputMapperServer:
+class MMInputCacheServer:
 
     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
-    def process_inputs(
+    def get_and_update(
         self,
         mm_inputs: List[Optional[MultiModalKwargs]],
         mm_hashes: List[str],
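For completeness, here is a rough client-side counterpart to the server sketch shown earlier (again with assumed names and plain-dict payloads rather than vLLM's real classes): on a cache hit the client sends None so the heavy payload never crosses the process boundary, and on a miss it runs the legacy input mapper, caches the result, and ships the full kwargs. Eviction and the precomputed-input path used by merged preprocessors are omitted for brevity.

from typing import Callable, Dict, List, Optional


class SketchMMCacheClient:
    """Toy stand-in for the frontend side of the mirrored cache."""

    def __init__(self, input_mapper: Callable[[dict], dict]):
        self.input_mapper = input_mapper
        self.mm_cache: Dict[str, dict] = {}

    def process_inputs(
        self,
        mm_items: List[dict],
        mm_hashes: List[str],
    ) -> List[Optional[dict]]:
        outgoing: List[Optional[dict]] = []
        for item, mm_hash in zip(mm_items, mm_hashes):
            if mm_hash in self.mm_cache:
                # The server's mirrored cache already holds this entry, so
                # only the hash needs to travel to P1.
                outgoing.append(None)
            else:
                mm_input = self.input_mapper(item)  # legacy input mapper
                self.mm_cache[mm_hash] = mm_input
                outgoing.append(mm_input)
        return outgoing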

vllm/v1/engine/processor.py

Lines changed: 14 additions & 6 deletions
@@ -17,7 +17,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 
 
 class Processor:
@@ -46,7 +46,7 @@ def __init__(
                                                         model_config)
 
         # Multi-modal (huggingface) input mapper
-        self.mm_input_mapper_client = MMInputMapperClient(model_config)
+        self.mm_input_cache_client = MMInputCacheClient(model_config)
 
         # Multi-modal hasher (for images)
         self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
@@ -106,16 +106,24 @@ def process_inputs(
         assert priority == 0, "vLLM V1 does not support priority at the moment."
         assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
-        # Process inputs.
+        # Process inputs, which includes:
+        # 1. Tokenize text prompt, with LoRA request if one exists.
+        # 2. For multimodal models with a merged preprocessor, preprocess
+        #    multimodal data and expand prompt token ids accordingly.
+        # 3. Apply prompt adapter to prompt token ids if one exists.
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
+        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
+
+        # Process prompt and prompt token ids.
+        # Only applicable to multimodal models with legacy input processor.
         processed_inputs = self.input_processor(preprocessed_inputs)
+
         self._validate_model_inputs(processed_inputs)
-        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
         if is_encoder_decoder_inputs(processed_inputs):
             decoder_inputs = SingletonInputsAdapter(
@@ -200,8 +208,8 @@ def process_inputs(
                 key=lambda mm_input: modality_order_dict[list(
                     mm_input.modalities)[0]])
 
-        # Apply mm input cache update (and input mapper if necessary).
-        sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
+        # Apply mm input cache update and legacy input mapper if one exists.
+        sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
            mm_data=decoder_mm_data,
            mm_hashes=sorted_mm_hashes,
            mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
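The comment block added in this file spells out the order that Processor.process_inputs now follows: preprocess (tokenize, merged multimodal preprocessor, prompt adapter), resolve the EOS token id, run the legacy per-model input processor, then validate. The stub-based sketch below is only meant to make that ordering concrete; every callable is an injected placeholder, not vLLM's API.

def sketch_process_inputs(prompt, preprocess, get_eos_token_id,
                          legacy_input_processor, validate,
                          lora_request=None):
    # Steps 1-3: tokenize the prompt (honoring any LoRA request), run the
    # merged multimodal preprocessor if the model has one, and apply the
    # prompt adapter to the token ids.
    preprocessed_inputs = preprocess(prompt, lora_request=lora_request)
    # The EOS token id can be resolved as soon as the (possibly
    # LoRA-specific) tokenizer is known, before any further processing.
    eos_token_id = get_eos_token_id(lora_request)
    # Legacy per-model input processor; effectively a pass-through for
    # models that already use the merged preprocessor.
    processed_inputs = legacy_input_processor(preprocessed_inputs)
    validate(processed_inputs)
    return processed_inputs, eos_token_id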

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 3 deletions
@@ -27,7 +27,7 @@
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
-from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
+from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -95,9 +95,10 @@ def __init__(
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
 
-        # NOTE: Initialized input mapper is only used for processing dummy
+        # NOTE: Initialized client is only used for processing dummy
         # multimodal data into multimodal kwargs for GPU memory profiling.
-        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
+        # Only applicable to multimodal models with legacy input mapper.
+        self.mm_input_mapper_profiling = MMInputCacheClient(self.model_config)
         self.mm_input_mapper_profiling.use_cache = False
 
         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
