diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f7d72d26e045..f900c11f6f83 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -17,12 +17,16 @@
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -102,7 +106,8 @@ def __init__(
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
+        self.max_num_reqs = _get_padded_batch_size(
+            scheduler_config.max_num_seqs)
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -113,8 +118,29 @@ def __init__(
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
 
-        self.model: Optional[nn.Module] = None
+        # Multi-modal data support
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.uses_mrope = model_config.uses_mrope
+        # TODO: Support M-RoPE (e.g., Qwen2-VL)
+        assert not self.uses_mrope, "TPU does not support M-RoPE yet."
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+
+        # Lazy initialization
+        # self.model: nn.Module  # Set after load_model
+        # KV caches for forward pass
+        self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        # req_id -> (input_id -> encoder_output)
+        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
+
+        # Request states.
+        self.requests: Dict[str, CachedRequestState] = {}
 
         # Persistent batch.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
@@ -122,18 +148,9 @@ def __init__(
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
+            vocab_size=model_config.get_vocab_size(),
         )
 
-        # Request states.
-        self.requests: Dict[str, CachedRequestState] = {}
-
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
-
-        # KV caches for forward pass
-        self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
-
         # Cached torch/numpy tensors
         self.num_swaps = 2
         self.cur_swap_id = 0
@@ -197,7 +214,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
-
+            self.encoder_cache.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -210,6 +227,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             if req_index is not None:
                 removed_req_indices.append(req_index)
 
+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)
+
         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
         # or running requests that are not scheduled in this step. We remove
@@ -578,6 +603,92 @@ def _prepare_decode(
                           input_positions=input_positions,
                           attn_metadata=attn_metadata)
 
+    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_inputs: List[MultiModalKwargs] = []
+        req_input_ids: List[Tuple[str, int]] = []
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.requests[req_id]
+            for input_id in encoder_input_ids:
+                mm_inputs.append(req_state.mm_inputs[input_id])
+                req_input_ids.append((req_id, input_id))
+
+        # Batch mm inputs as much as we can: if a request in the batch has
+        # multiple modalities or a different modality than the previous one,
+        # we process it separately to preserve item order.
+        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+        # in the same batch while still being able to benefit from batching
+        # multimodal inputs. The proper solution should be reordering the
+        # encoder outputs.
+        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
+
+        encoder_outputs = []
+        for grouped_mm_inputs in grouped_mm_inputs_list:
+            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                           device=self.device)
+
+            # Run the encoder.
+            # `curr_group_outputs` is either of the following:
+            # 1. A tensor of shape (num_items, feature_size, hidden_size)
+            # in case feature_size is fixed across all multimodal items.
+            # 2. A list or tuple (length: num_items) of tensors, each of shape
+            # (feature_size, hidden_size) in case the feature size is dynamic
+            # depending on the input multimodal items.
+            curr_group_outputs = self.model.get_multimodal_embeddings(
+                **batched_mm_inputs)
+
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
+
+        # Cache the encoder outputs.
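+        # Keyed as req_id -> input_id so _gather_req_encoder_outputs can later
+        # slice out only the positions scheduled in a given step.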
+        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
+            if req_id not in self.encoder_cache:
+                self.encoder_cache[req_id] = {}
+            self.encoder_cache[req_id][input_id] = output
+
+    def _gather_req_encoder_outputs(
+        self,
+        req_id: str,
+        scheduler_output: "SchedulerOutput",
+    ) -> List[torch.Tensor]:
+        encoder_outputs: List[torch.Tensor] = []
+        assert req_id is not None
+        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+        req_state = self.requests[req_id]
+        num_computed_tokens = req_state.num_computed_tokens
+        mm_positions = req_state.mm_positions
+        for i, pos_info in enumerate(mm_positions):
+            start_pos = pos_info["offset"]
+            num_encoder_tokens = pos_info["length"]
+
+            # The encoder output is needed if the two ranges overlap:
+            # [num_computed_tokens,
+            #  num_computed_tokens + num_scheduled_tokens) and
+            # [start_pos, start_pos + num_encoder_tokens)
+            if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                # The encoder output is not needed in this step.
+                break
+            if start_pos + num_encoder_tokens <= num_computed_tokens:
+                # The encoder output is already processed and stored
+                # in the decoder's KV cache.
+                continue
+
+            start_idx = max(num_computed_tokens - start_pos, 0)
+            end_idx = min(
+                num_computed_tokens - start_pos + num_scheduled_tokens,
+                num_encoder_tokens)
+            assert start_idx < end_idx
+            assert req_id in self.encoder_cache
+            assert i in self.encoder_cache[req_id]
+            encoder_output = self.encoder_cache[req_id][i]
+            encoder_outputs.append(encoder_output[start_idx:end_idx])
+        return encoder_outputs
+
     @torch.no_grad()
     def execute_model(
         self,
@@ -586,6 +697,10 @@ def execute_model(
         # Update cached state
         self._update_states(scheduler_output)
 
+        if self.is_multimodal_model:
+            # Run the multimodal encoder if any.
+            self._execute_encoder(scheduler_output)
+
         # If necessary, swap decodes/prompts to have all decodes on the start
         ensure_decodes_first(self.input_batch)
 
@@ -616,13 +731,39 @@ def execute_model(
                                                    num_scheduled_tokens)
                 is_first = False
 
+            if self.is_multimodal_model:
+                # NOTE(woosuk): To unify token ids and soft tokens (vision
+                # embeddings), we always use embeddings (rather than token ids)
+                # as input to the multimodal model, even when the input is text.
+                input_ids = prompt_data.input_tokens
+                # NOTE(mgoin): Once we run prompts together, the call to
+                # _gather_req_encoder_outputs should move into _execute_encoder
+                encoder_outputs = self._gather_req_encoder_outputs(
+                    req_id, scheduler_output)
+                if encoder_outputs:
+                    inputs_embeds = self.model.get_input_embeddings(
+                        input_ids, encoder_outputs)
+                else:
+                    inputs_embeds = self.model.get_input_embeddings(input_ids)
+                input_ids = None
+            else:
+                # For text-only models, we use token ids as input.
+                # While it is possible to use embeddings as input just like the
+                # multimodal models, it is not desirable for performance since
+                # then the embedding layer is not included in the compiled graph.
+                input_ids = prompt_data.input_tokens
+                inputs_embeds = None
+
             # Run forward pass
             with set_forward_context(prompt_data.attn_metadata,
                                      self.vllm_config):
                 assert self.model is not None
-                selected_token_ids = self.model(prompt_data.input_tokens,
-                                                prompt_data.input_positions,
-                                                self.kv_caches)
+                selected_token_ids = self.model(
+                    input_ids=input_ids,
+                    positions=prompt_data.input_positions,
+                    kv_caches=self.kv_caches,
+                    inputs_embeds=inputs_embeds,
+                )
 
             # In parallel to TPU execution, prepare the next iteration
             if i < num_prompts - 1:
@@ -654,13 +795,20 @@ def execute_model(
         if decode_data is None:
             decode_data = self._prepare_decode(pd_info.decode_req_ids)
 
+        # TODO: Once we combine prompts and decodes, we should use the same
+        # choice of input_ids or inputs_embeds throughout execute_model.
+        inputs_embeds = None
+
         # Run forward pass
         with set_forward_context(decode_data.attn_metadata,
                                  self.vllm_config):
             assert self.model is not None
-            selected_token_ids = self.model(decode_data.input_tokens,
-                                            decode_data.input_positions,
-                                            self.kv_caches)
+            selected_token_ids = self.model(
+                input_ids=decode_data.input_tokens,
+                positions=decode_data.input_positions,
+                kv_caches=self.kv_caches,
+                inputs_embeds=inputs_embeds,
+            )
 
         # Transfer sampled tokens from TPU to CPU
         decode_token_ids_cpu = selected_token_ids.cpu()
@@ -727,7 +875,7 @@ def load_model(self) -> None:
                                    fullgraph=True,
                                    dynamic=False)
 
-    def dummy_run(
+    def _dummy_run(
         self,
         kv_caches,
         num_tokens: int,
@@ -740,9 +888,19 @@ def dummy_run(
         exec_mode = ExecutionMode(exec_mode)
         if exec_mode.is_prefill():
             seq_len = (seq_len + 15) // 16 * 16
-            token_ids = torch.zeros((num_tokens, seq_len),
-                                    dtype=torch.int32,
-                                    device=self.device)
+
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = torch.zeros(
+                    (num_tokens, seq_len, self.hidden_size),
+                    dtype=self.dtype,
+                    device=self.device)
+            else:
+                input_ids = torch.zeros((num_tokens, seq_len),
+                                        dtype=torch.int32,
+                                        device=self.device)
+                inputs_embeds = None
+
             position_ids = torch.zeros((num_tokens, seq_len),
                                        dtype=torch.int32,
                                        device=self.device)
@@ -786,10 +944,14 @@ def dummy_run(
                 effective_query_lens=effective_query_lens,
             )
         else:
+            # Decode
             assert seq_len == 1
-            token_ids = torch.zeros((num_tokens, seq_len),
+            input_ids = torch.zeros((num_tokens, seq_len),
                                     dtype=torch.int32,
                                     device=self.device)
+            # TODO: Once we combine prompts and decodes, we should use the same
+            # choice of input_ids or inputs_embeds throughout _dummy_run.
+            inputs_embeds = None
             position_ids = torch.zeros((num_tokens, seq_len),
                                        dtype=torch.int32,
                                        device=self.device)
@@ -822,13 +984,16 @@ def dummy_run(
         # in the first run, but can be skipped afterwards as we cache the XLA
         # graphs in the disk (VLLM_XLA_CACHE_PATH).
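+        # Prefill shapes vary along the sequence dimension (dim 1), while
+        # decode shapes vary along the batch dimension (dim 0), so mark the
+        # matching dimension dynamic for torch.compile.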
         if exec_mode.is_prefill():
-            # Prefll
-            torch._dynamo.mark_dynamic(token_ids, 1)
+            # Prefill
+            if self.is_multimodal_model:
+                torch._dynamo.mark_dynamic(inputs_embeds, 1)
+            else:
+                torch._dynamo.mark_dynamic(input_ids, 1)
             torch._dynamo.mark_dynamic(position_ids, 1)
             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
         else:
             # Decode
-            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(input_ids, 0)
             torch._dynamo.mark_dynamic(position_ids, 0)
             torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
             torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
@@ -836,7 +1001,12 @@ def dummy_run(
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(token_ids, position_ids, kv_caches)
+            self.model(
+                input_ids=input_ids,
+                positions=position_ids,
+                kv_caches=kv_caches,
+                inputs_embeds=inputs_embeds,
+            )
 
     def capture_model(self) -> None:
         """Compile the model."""
@@ -848,10 +1018,10 @@ def capture_model(self) -> None:
         for batch_size in [1]:
             seq_len = 16
             while seq_len <= self.model_config.max_model_len:
-                self.dummy_run(self.kv_caches,
-                               batch_size,
-                               seq_len,
-                               exec_mode=ExecutionMode.PREFILL)
+                self._dummy_run(self.kv_caches,
+                                batch_size,
+                                seq_len,
+                                exec_mode=ExecutionMode.PREFILL)
                 xm.wait_device_ops()
                 logger.info("  batch_size: %d, seq_len: %d", batch_size,
                             seq_len)
@@ -872,10 +1042,10 @@ def capture_model(self) -> None:
         for batch_size in [1]:
             seq_len = 16
             while seq_len <= self.model_config.max_model_len:
-                self.dummy_run(self.kv_caches,
-                               batch_size,
-                               seq_len,
-                               exec_mode=ExecutionMode.PREFIX_PREFILL)
+                self._dummy_run(self.kv_caches,
+                                batch_size,
+                                seq_len,
+                                exec_mode=ExecutionMode.PREFIX_PREFILL)
                 xm.wait_device_ops()
                 logger.info("  batch_size: %d, seq_len: %d", batch_size,
                             seq_len)
@@ -896,10 +1066,10 @@ def capture_model(self) -> None:
         seq_len = 1
         batch_size = 8  # Must be in sync with _get_padded_batch_size()
         while True:
-            self.dummy_run(self.kv_caches,
-                           batch_size,
-                           seq_len,
-                           exec_mode=ExecutionMode.DECODE)
+            self._dummy_run(self.kv_caches,
+                            batch_size,
+                            seq_len,
+                            exec_mode=ExecutionMode.DECODE)
             xm.wait_device_ops()
             logger.info("  batch_size: %d, seq_len: %d", batch_size, seq_len)
 
@@ -958,21 +1128,20 @@ def __init__(self, model: nn.Module):
 
     def forward(
         self,
-        token_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
 
         Args:
-            token_ids: The input token IDs of shape [batch_size, seq_len].
-            position_ids: The input position IDs of shape [batch_size, seq_len].
-            input_lens: The actual input lengths of shape [batch_size].
-            t: The sampling temperature of shape [batch_size].
-            p: The top-p probability of shape [batch_size].
-            num_samples: Number of samples to draw from each logits vector.
+            input_ids: The input token IDs of shape [batch_size, seq_len].
+            positions: The input position IDs of shape [batch_size, seq_len].
             kv_caches: The key and value caches. They can be None during the
                 memory profiling at initialization.
+            inputs_embeds: The input embeddings of shape [batch_size, seq_len,
+                hidden_size]. It is used for multimodal models.
         """
         # Skip this in memory profiling at initialization.
         if kv_caches[0][0].numel() > 0:
@@ -997,7 +1166,11 @@ def forward(
             attn_metadata.slot_mapping = slot_mapping
 
         assert self.model is not None
-        hidden_states = self.model(token_ids, position_ids)
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+        )
 
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, None)
@@ -1007,6 +1180,12 @@ def forward(
         argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
         return argmax_token_ids
 
+    def get_multimodal_embeddings(self, *args, **kwargs):
+        return self.model.get_multimodal_embeddings(*args, **kwargs)
+
+    def get_input_embeddings(self, *args, **kwargs):
+        return self.model.get_input_embeddings(*args, **kwargs)
+
 
 def swap_positions(b: InputBatch, id_1, id_2):
     assert id_1 != id_2
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index c236f263eddb..c5fc00a8daa7 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -124,7 +124,7 @@ def determine_available_memory(self) -> int:
             self.vllm_config.compilation_config.static_forward_context,
             runner_kv_caches)
 
-        self.model_runner.dummy_run(
+        self.model_runner._dummy_run(
             runner_kv_caches,
             num_tokens=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,