From 6071a4b86672b9ed02cc99e1ae314934467c8ae4 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 18 Jan 2025 17:24:15 -0800 Subject: [PATCH 01/75] skeleton --- vllm/v1/core/scheduler.py | 39 +++++++++++----- vllm/v1/engine/core.py | 9 ++++ vllm/v1/outputs.py | 10 ++-- vllm/v1/request.py | 16 +++++++ vllm/v1/sample/metadata.py | 2 + vllm/v1/sample/rejection_sampler.py | 41 ++++++++++++++++ vllm/v1/sample/sampler.py | 8 ++++ vllm/v1/worker/gpu_input_batch.py | 8 +++- vllm/v1/worker/gpu_model_runner.py | 72 +++++++++++++++++++++-------- 9 files changed, 168 insertions(+), 37 deletions(-) create mode 100644 vllm/v1/sample/rejection_sampler.py diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index baaf3329dc79..33c32603cc0e 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -84,8 +84,8 @@ def __init__( def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. - # Each request just has the num_computed_tokens and num_tokens, - # which is equal to len(prompt_token_ids) + len(output_token_ids). + # Each request just has the num_computed_tokens and num_tokens. + # num_tokens = len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids) # At each step, the scheduler tries to assign tokens to the requests # so that each request's num_computed_tokens can catch up its # num_tokens. This is general enough to cover chunked prefills, @@ -101,6 +101,9 @@ def schedule(self) -> "SchedulerOutput": token_budget = self.max_num_scheduled_tokens # Encoder-related. scheduled_encoder_inputs: Dict[str, List[int]] = {} + # Spec Decode-related. + spec_decode = False + scheduled_spec_decode_tokens: Dict[str, List[int]] = {} encoder_budget = self.max_num_encoder_input_tokens # First, schedule the RUNNING requests. @@ -116,7 +119,7 @@ def schedule(self) -> "SchedulerOutput": assert not has_partial_request assert token_budget > 0 request = self.running[req_index] - num_new_tokens = request.num_tokens - request.num_computed_tokens + num_new_tokens = request.num_tokens_with_spec - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -172,6 +175,10 @@ def schedule(self) -> "SchedulerOutput": for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + + if request.spec_token_ids: + spec_decode = True + scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids # Next, schedule the WAITING requests. if not preempted_reqs: @@ -295,6 +302,8 @@ def schedule(self) -> "SchedulerOutput": num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, + use_spec_decode=spec_decode, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, num_common_prefix_blocks=num_common_prefix_blocks, preempted_req_ids=preempted_req_ids, # finished_req_ids is an existing state in the scheduler, @@ -399,12 +408,16 @@ def update_from_output( ) -> List[EngineCoreOutput]: # NOTE(woosuk): This method doesn't consider speculative decoding. 
sampled_token_ids = model_runner_output.sampled_token_ids - num_scheduled_tokens = scheduler_output.num_scheduled_tokens + # num_scheduled_tokens = scheduler_output.num_scheduled_tokens + new_running: List[Request] = [] engine_core_outputs: List[EngineCoreOutput] = [] for request in self.running: req_id = request.request_id - request.num_computed_tokens += num_scheduled_tokens[req_id] + req_index = model_runner_output.req_id_to_index[req_id] + token_ids = sampled_token_ids[req_index] + request.num_computed_tokens += len(token_ids) + # When the request's num_computed_tokens catches up its num_tokens, # the request generates output tokens. Otherwise, we ignore the # sampler output for the request. @@ -420,13 +433,10 @@ def update_from_output( # in the decoder's KV cache. self.encoder_cache_manager.free(request, input_id) - if request.num_computed_tokens == request.num_tokens: - req_index = model_runner_output.req_id_to_index[req_id] - # NOTE(woosuk): Currently, we assume that each request - # generates at most one token at each step. - token_id = sampled_token_ids[req_index] - request.append_output_token_ids(token_id) - num_new_tokens = 1 + if request.num_computed_tokens >= request.num_tokens: + request.clear_spec_tokens() + request.append_output_token_ids(token_ids) + num_new_tokens = len(token_ids) # TODO: Update the KV cache manager for prefix caching. # Check for stop and update request state. @@ -450,6 +460,9 @@ def update_from_output( self.running = new_running return engine_core_outputs + # TODO: the following logic does not consider + # when multiple tokens are generated in a + # single forward pass def _check_stop(self, request: Request) -> bool: if (request.num_tokens >= self.max_model_len or request.num_output_tokens >= request.max_tokens): @@ -603,6 +616,8 @@ class SchedulerOutput: num_scheduled_tokens: Dict[str, int] total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] + use_spec_decode: bool + scheduled_spec_decode_tokens: Dict[str, List[int]] num_common_prefix_blocks: int preempted_req_ids: Set[str] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 975ce11fe8af..d4daa46e1e15 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -120,6 +120,15 @@ def step(self) -> List[EngineCoreOutput]: if not self.scheduler.has_unfinished_requests(): return [] + logger.info("Running EngineCore step.") + # Append tokens to requests directly + # to mimic ngram proposal. + # Only change requests in the running queue. + # We don't do spec decode in the prefill phase for now. + # We don't handle prefill kv cache for now. + for req in self.scheduler.running: + req.append_spec_token_ids([1] * 5) + scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index acc3a944e21b..b7d8b46d45d6 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -7,8 +7,9 @@ @dataclass class SamplerOutput: - # [num_reqs] - sampled_token_ids: List[int] + # num_reqs x [num_generated_tokens] + # num_generated_tokens might be different for each request. + sampled_token_ids: List[List[int]] # [num_reqs, max_num_logprobs + 1] logprob_token_ids: Optional[torch.Tensor] @@ -30,8 +31,9 @@ class ModelRunnerOutput: # req_id -> index req_id_to_index: Dict[str, int] - # [num_reqs] - sampled_token_ids: List[int] + # num_reqs x [num_generated_tokens] + # num_generated_tokens might be different for each request. 
+ sampled_token_ids: List[List[int]] # [num_reqs, max_num_logprobs + 1] logprob_token_ids_cpu: Optional[torch.Tensor] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 45450165eaef..d686ea66ac83 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -49,6 +49,7 @@ def __init__( self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: List[int] = [] self._all_token_ids: List[int] = self.prompt_token_ids.copy() + self._spec_token_ids: List[int] = [] self.num_computed_tokens = 0 # Multi-modal related @@ -99,10 +100,25 @@ def append_output_token_ids( token_ids = [token_ids] self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) + + def append_spec_token_ids( + self, + token_ids: Union[int, List[int]], + ) -> None: + if isinstance(token_ids, int): + token_ids = [token_ids] + self._spec_token_ids.extend(token_ids) + + def clear_spec_tokens(self) -> None: + self._spec_token_ids = [] @property def num_tokens(self) -> int: return len(self._all_token_ids) + + @property + def num_tokens_with_spec(self) -> int: + return len(self._all_token_ids) + len(self._spec_token_ids) @property def num_output_tokens(self) -> int: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index d60f7eb5d76f..8cb5c6823187 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -10,6 +10,8 @@ class SamplingMetadata: temperature: torch.Tensor all_greedy: bool all_random: bool + rejection_sampling: bool + spec_token_ids: List[List[int]] top_p: torch.Tensor top_k: torch.Tensor diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py new file mode 100644 index 000000000000..b7727347167c --- /dev/null +++ b/vllm/v1/sample/rejection_sampler.py @@ -0,0 +1,41 @@ +import torch +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata + +class RejectionSampler(nn.Module): + def sample(self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> SamplerOutput: + # num_reqs x [num_specuated_tokens] + spec_token_ids = sampling_metadata.spec_token_ids + # only argmax is supported for now + output_token_ids_cpu = logits.argmax(dim=-1).view(-1).tolist() + + sampled_token_ids = [] + # Stop at the first mismatch place. + # spec_tokens: [1, 2, 3] + # output_tokens: [1, 2, 4, 5] + # sampled_tokens: [1, 2, 4] + output_token_start_idx = 0 + for spec_tokens in spec_token_ids: + num_spec_tokens = len(spec_tokens) + output_tokens = output_token_ids_cpu[output_token_start_idx: output_token_start_idx + 1 + num_spec_tokens] + i = 0 + while i < len(spec_tokens): + if spec_tokens[i] != output_tokens[i]: + break + i += 1 + # +1 to include the bonus token. 
+ i += 1 + output_tokens = output_tokens[:i] + sampled_token_ids.append(output_tokens) + output_token_start_idx += num_spec_tokens + 1 + assert output_token_start_idx == len(output_token_ids_cpu) + + return SamplerOutput(sampled_token_ids=sampled_token_ids, + logprob_token_ids=None, + logprobs=None, + prompt_logprob_token_ids=None, + prompt_logprobs=None) + + \ No newline at end of file diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 7cd42ca211a2..8c92cdad2992 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -9,6 +9,7 @@ from vllm.v1.sample.ops.penalties import (apply_all_penalties, apply_min_token_penalties) from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler +from vllm.v1.sample.rejection_sampler import RejectionSampler _SAMPLING_EPS = 1e-5 @@ -18,12 +19,19 @@ class Sampler(nn.Module): def __init__(self): super().__init__() self.topk_topp_sampler = TopKTopPSampler() + self.rejection_sampler = RejectionSampler() def forward( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + if sampling_metadata.rejection_sampling: + return self.rejection_sampler.sample( + logits, + sampling_metadata, + ) + needs_logprobs = sampling_metadata.max_num_logprobs > 0 if needs_logprobs: # NOTE(woosuk): Use the original logits (before any penalties or diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 40494e64b22f..5e12cc1afaef 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -320,6 +320,8 @@ def make_sampling_metadata( self, req_id_output_token_ids: Dict[str, List[int]], skip_copy: bool = False, + rejection_sampling: bool = False, + req_id_to_spec_token_ids: Optional[Dict[str, List[int]]] = None, ) -> SamplingMetadata: if not skip_copy: self.temperature[:self.num_reqs].copy_( @@ -347,7 +349,7 @@ def make_sampling_metadata( self.prompt_token_ids = self._make_prompt_token_ids_tensor() output_token_ids: List[List[int]] = [] - + spec_token_ids: List[List[int]] = [] for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None # Currently we create a tensor for output_token_ids from scratch @@ -358,11 +360,14 @@ def make_sampling_metadata( # TODO - Replace this with incremental update to output token # statistics. 
output_token_ids.append(req_id_output_token_ids[req_id]) + if req_id_to_spec_token_ids is not None: + spec_token_ids.append(req_id_to_spec_token_ids[req_id]) return SamplingMetadata( temperature=self.temperature[:self.num_reqs], all_greedy=self.all_greedy, all_random=self.all_random, + rejection_sampling=rejection_sampling, top_p=self.top_p[:self.num_reqs], top_k=self.top_k[:self.num_reqs], no_top_p=self.no_top_p, @@ -374,6 +379,7 @@ def make_sampling_metadata( presence_penalties=self.presence_penalties[:self.num_reqs], repetition_penalties=self.repetition_penalties[:self.num_reqs], output_token_ids=output_token_ids, + spec_token_ids=spec_token_ids, min_tokens=self.min_tokens[:self.num_reqs], stop_token_ids=self.stop_token_ids[:self.num_reqs], no_penalties=self.no_penalties, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a1d4f9b13578..ac2a3453c9e8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional import numpy as np import torch @@ -264,7 +264,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if removed_req_indices: self.input_batch.condense(removed_req_indices) - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ + -> Tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -278,7 +279,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = [] max_num_scheduled_tokens = 0 - for req_id in self.input_batch.req_ids[:num_reqs]: + for i, req_id in enumerate(self.input_batch.req_ids): + if i == num_reqs: + break assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) @@ -309,6 +312,24 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # Add spec decode tokens to input_batch.token_ids_cpu. + # Get spec decode logits indices. + spec_query_end_loc = 0 + spec_decode_logits_indices = [] + for i, req_id in enumerate(self.input_batch.req_id): + if i == num_reqs: + break + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] + spec_query_end_loc += num_scheduled_tokens + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens[req_id] + for j, spec_token_id in enumerate(spec_token_ids): + # +1 here because the input for verification is [last_output_token_id] + spec_token_ids + self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + j] = spec_token_id + # -1 here because the input for verification is [last_output_token_id] + spec_token_ids + spec_decode_logits_indices.extend(range(spec_query_end_loc - len(spec_token_ids) - 1, spec_query_end_loc)) + # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. 
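For reference, the index arithmetic in the hunk above can be exercised in isolation with the short sketch below. It is illustrative only (the helper name and its arguments are not vLLM APIs) and assumes each speculating request is scheduled for one verified token plus its draft tokens.

from typing import Dict, List

def spec_logits_indices(req_ids: List[str],
                        num_scheduled_tokens: Dict[str, int],
                        spec_tokens: Dict[str, List[int]]) -> List[int]:
    # Flat positions (into the batched query) whose logits are needed for
    # verification: the last output token plus every draft token position.
    indices: List[int] = []
    query_end = 0  # running end offset into the flattened token batch
    for req_id in req_ids:
        query_end += num_scheduled_tokens[req_id]
        num_spec = len(spec_tokens.get(req_id, []))
        indices.extend(range(query_end - num_spec - 1, query_end))
    return indices

# Request "a": 1 verified token + 3 drafts scheduled; request "b": plain decode.
assert spec_logits_indices(["a", "b"], {"a": 4, "b": 1},
                           {"a": [7, 8, 9]}) == [0, 1, 2, 3, 4]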
@@ -454,12 +475,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): cu_prefix_kv_lens=cu_prefix_kv_lens, cu_suffix_kv_lens=cu_suffix_kv_lens, ) - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. - # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 + + if scheduler_output.use_spec_decode: + logits_indices = torch.tensor(spec_decode_logits_indices, device=self.device) + else: + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices def _prepare_sampling( @@ -479,7 +504,10 @@ def _prepare_sampling( for req_id, req in self.requests.items()} sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, skip_copy) + req_id_output_token_ids, + skip_copy, + scheduler_output.use_spec_decode, + scheduler_output.scheduled_spec_decode_tokens) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -618,6 +646,7 @@ def execute_model( hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(hidden_states, None) + logger.info("logits: %s", str(logits.shape)) # Sample the next token and get logprobs if needed. sampling_metadata = self._prepare_sampling(scheduler_output) @@ -630,18 +659,21 @@ def execute_model( # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + for i, req_id in enumerate(self.input_batch.req_ids): + if i == num_reqs: + break assert req_id is not None req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len <= req_state.num_tokens - if seq_len == req_state.num_tokens: - # Append the sampled token to the output token ids. - token_id = sampled_token_ids[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - self.input_batch.num_tokens[i] += 1 - req_state.output_token_ids.append(token_id) + seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] + # assert seq_len <= req_state.num_tokens + if seq_len >= req_state.num_tokens: + # We don't rewind the generator state for requests now + # because spec decode only supports greedy decoding for now. + token_ids = sampled_token_ids[i] + for j, token_id in enumerate(token_ids): + self.input_batch.token_ids_cpu[i, req_state.num_computed_tokens + j] = token_id + self.input_batch.num_tokens[i] += len(token_ids) + req_state.output_token_ids.extend(token_ids) else: # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. 
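The acceptance rule implemented by RejectionSampler.sample in this patch (greedy verification only) amounts to: accept draft tokens until the first mismatch with the target model's argmax, then keep one extra "bonus" token. A self-contained sketch, with illustrative names that are not part of the patch:

from typing import List

def greedy_accept(spec_tokens: List[int], target_tokens: List[int]) -> List[int]:
    # target_tokens has len(spec_tokens) + 1 entries: the target model's argmax
    # at each draft position plus one extra position.
    assert len(target_tokens) == len(spec_tokens) + 1
    n_accepted = 0
    for draft, target in zip(spec_tokens, target_tokens):
        if draft != target:
            break
        n_accepted += 1
    # +1 keeps the bonus token: either the correction at the first mismatch
    # or the extra token sampled after all drafts were accepted.
    return target_tokens[:n_accepted + 1]

assert greedy_accept([1, 2, 3], [1, 2, 4, 5]) == [1, 2, 4]
assert greedy_accept([1, 2, 3], [1, 2, 3, 9]) == [1, 2, 3, 9]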
From be798e7bbe1d305971c47f831e7a93c544ed3abf Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 18 Jan 2025 23:00:21 -0800 Subject: [PATCH 02/75] runnable but incorrect --- tests/v1/sample/test_sampler.py | 2 ++ tests/v1/worker/test_gpu_input_batch.py | 2 ++ vllm/v1/core/kv_cache_manager.py | 8 ++++++-- vllm/v1/request.py | 4 ++++ vllm/v1/sample/rejection_sampler.py | 5 +++++ vllm/v1/sample/sampler.py | 2 +- vllm/v1/worker/gpu_input_batch.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 5 ++++- 8 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 5ebf72927cfd..a9b922db5f98 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -63,6 +63,7 @@ def _create_default_sampling_metadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, all_random=False, + rejection_sampling=False, top_p=torch.empty(batch_size, ), top_k=torch.empty(batch_size, ), no_top_p=True, @@ -72,6 +73,7 @@ def _create_default_sampling_metadata( prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, + spec_token_ids=[], frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 694ce81ff6e2..85472f4aa812 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -85,6 +85,7 @@ def _construct_expected_sampling_metadata( temperature=torch.tensor(temperature, dtype=torch.float, device=device), all_greedy=False, all_random=True, + rejection_sampling=False, top_p=torch.tensor(top_p, dtype=torch.float, device=device), top_k=torch.tensor(top_k, dtype=torch.int, device=device), no_top_p=all(x == 1.0 for x in top_p), @@ -107,6 +108,7 @@ def _construct_expected_sampling_metadata( repetition_penalties, dtype=torch.float, device=device), output_token_ids=output_token_ids, + spec_token_ids=[], min_tokens=min_tokens, stop_token_ids=stop_token_ids, no_penalties=(all(x ==0 for x in presence_penalties) and \ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 1cbff1e2d767..8eda21eee5ba 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -162,8 +162,12 @@ def append_slots( # TODO(rickyx): When supporting speculative decoding, we will need to # differentiate between them so that we can know how many blocks are # full after appending the actual tokens. - num_full_blocks_after_append = (request.num_computed_tokens + - num_tokens) // self.block_size + + # Does not include speculative tokens. + # FIXME: The logic is not correct because + # we never count speculative tokens that are accepted. 
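# Illustration only (not part of the patch): the accounting on the next added
# line treats only non-draft tokens as cacheable when deciding how many blocks
# became full after the append.
block_size = 16
num_computed_tokens, num_tokens, num_spec_tokens = 30, 8, 5
num_cached_tokens = num_computed_tokens + num_tokens - num_spec_tokens  # 33
assert num_cached_tokens // block_size == 2  # two full blocks after append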
+ num_cached_tokens = request.num_computed_tokens + num_tokens - len(request.spec_token_ids) + num_full_blocks_after_append = num_cached_tokens // self.block_size assert num_full_blocks_after_append <= len(req_blocks) new_full_blocks = req_blocks[ diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d686ea66ac83..0df43398cce2 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -111,6 +111,10 @@ def append_spec_token_ids( def clear_spec_tokens(self) -> None: self._spec_token_ids = [] + + @property + def spec_token_ids(self) -> ConstantList[int]: + return ConstantList(self._spec_token_ids) @property def num_tokens(self) -> int: diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index b7727347167c..7cb807173db1 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,7 +1,10 @@ import torch +import torch.nn as nn from vllm.v1.outputs import SamplerOutput +from vllm.logger import init_logger from vllm.v1.sample.metadata import SamplingMetadata +logger = init_logger(__name__) class RejectionSampler(nn.Module): def sample(self, logits: torch.Tensor, @@ -30,6 +33,8 @@ def sample(self, output_tokens = output_tokens[:i] sampled_token_ids.append(output_tokens) output_token_start_idx += num_spec_tokens + 1 + print("Proposed token ids", spec_tokens._x) + print("Output token ids", output_tokens) assert output_token_start_idx == len(output_token_ids_cpu) return SamplerOutput(sampled_token_ids=sampled_token_ids, diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 8c92cdad2992..9cb913aa4425 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -60,7 +60,7 @@ def forward( # NOTE: CPU-GPU synchronization happens here. sampler_output = SamplerOutput( - sampled_token_ids=sampled.tolist(), + sampled_token_ids=[[x] for x in sampled.tolist()], logprob_token_ids=topk_indices, logprobs=topk_logprobs, prompt_logprob_token_ids=None, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 5e12cc1afaef..d15a0435cda8 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -360,7 +360,7 @@ def make_sampling_metadata( # TODO - Replace this with incremental update to output token # statistics. output_token_ids.append(req_id_output_token_ids[req_id]) - if req_id_to_spec_token_ids is not None: + if rejection_sampling: spec_token_ids.append(req_id_to_spec_token_ids[req_id]) return SamplingMetadata( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ac2a3453c9e8..594b39315ece 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -317,9 +317,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # Get spec decode logits indices. spec_query_end_loc = 0 spec_decode_logits_indices = [] - for i, req_id in enumerate(self.input_batch.req_id): + for i, req_id in enumerate(self.input_batch.req_ids): if i == num_reqs: break + if req_id not in scheduler_output.scheduled_spec_decode_tokens: + continue num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] spec_query_end_loc += num_scheduled_tokens @@ -370,6 +372,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # Copy the tensors to the GPU. 
self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + print("input_ids", self.input_ids[:total_num_scheduled_tokens]) self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( From f0976dd28fc67f23a8e6b3d24cd7957899c35d1f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 19 Jan 2025 06:55:26 -0800 Subject: [PATCH 03/75] fix --- vllm/v1/core/scheduler.py | 5 +++-- vllm/v1/engine/core.py | 14 +++++++------- vllm/v1/worker/gpu_model_runner.py | 30 +++++++++++++++--------------- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 33c32603cc0e..3221aa7a83fc 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -408,7 +408,7 @@ def update_from_output( ) -> List[EngineCoreOutput]: # NOTE(woosuk): This method doesn't consider speculative decoding. sampled_token_ids = model_runner_output.sampled_token_ids - # num_scheduled_tokens = scheduler_output.num_scheduled_tokens + num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] engine_core_outputs: List[EngineCoreOutput] = [] @@ -416,7 +416,8 @@ def update_from_output( req_id = request.request_id req_index = model_runner_output.req_id_to_index[req_id] token_ids = sampled_token_ids[req_index] - request.num_computed_tokens += len(token_ids) + # FIXME: have a cleaner way to handle this + request.num_computed_tokens += num_scheduled_tokens[req_id] - (len(request.spec_token_ids) + 1 - len(token_ids)) # When the request's num_computed_tokens catches up its num_tokens, # the request generates output tokens. Otherwise, we ignore the diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d4daa46e1e15..18b985191c5a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -121,13 +121,13 @@ def step(self) -> List[EngineCoreOutput]: return [] logger.info("Running EngineCore step.") - # Append tokens to requests directly - # to mimic ngram proposal. - # Only change requests in the running queue. - # We don't do spec decode in the prefill phase for now. - # We don't handle prefill kv cache for now. - for req in self.scheduler.running: - req.append_spec_token_ids([1] * 5) + # # Append tokens to requests directly + # # to mimic ngram proposal. + # # Only change requests in the running queue. + # # We don't do spec decode in the prefill phase for now. + # # We don't handle prefill kv cache for now. + # for req in self.scheduler.running: + # req.append_spec_token_ids([1] * 5) scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 594b39315ece..6c736ac6783e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -317,20 +317,20 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # Get spec decode logits indices. 
spec_query_end_loc = 0 spec_decode_logits_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - if i == num_reqs: - break - if req_id not in scheduler_output.scheduled_spec_decode_tokens: - continue - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] - spec_query_end_loc += num_scheduled_tokens - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens[req_id] - for j, spec_token_id in enumerate(spec_token_ids): - # +1 here because the input for verification is [last_output_token_id] + spec_token_ids - self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + j] = spec_token_id - # -1 here because the input for verification is [last_output_token_id] + spec_token_ids - spec_decode_logits_indices.extend(range(spec_query_end_loc - len(spec_token_ids) - 1, spec_query_end_loc)) + # for i, req_id in enumerate(self.input_batch.req_ids): + # if i == num_reqs: + # break + # if req_id not in scheduler_output.scheduled_spec_decode_tokens: + # continue + # num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + # num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] + # spec_query_end_loc += num_scheduled_tokens + # spec_token_ids = scheduler_output.scheduled_spec_decode_tokens[req_id] + # for j, spec_token_id in enumerate(spec_token_ids): + # # +1 here because the input for verification is [last_output_token_id] + spec_token_ids + # self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + j] = spec_token_id + # # -1 here because the input for verification is [last_output_token_id] + spec_token_ids + # spec_decode_logits_indices.extend(range(spec_query_end_loc - len(spec_token_ids) - 1, spec_query_end_loc)) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large @@ -674,7 +674,7 @@ def execute_model( # because spec decode only supports greedy decoding for now. token_ids = sampled_token_ids[i] for j, token_id in enumerate(token_ids): - self.input_batch.token_ids_cpu[i, req_state.num_computed_tokens + j] = token_id + self.input_batch.token_ids_cpu[i, req_state.num_computed_tokens + 1 + j] = token_id self.input_batch.num_tokens[i] += len(token_ids) req_state.output_token_ids.extend(token_ids) else: From 6039933b774d25fff5ade2157bcfd91c9af2aeed Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 19 Jan 2025 07:19:09 -0800 Subject: [PATCH 04/75] pass for simple non spec case --- vllm/v1/engine/core.py | 14 ++++++------- vllm/v1/worker/gpu_model_runner.py | 33 ++++++++++++++++-------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 18b985191c5a..d4daa46e1e15 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -121,13 +121,13 @@ def step(self) -> List[EngineCoreOutput]: return [] logger.info("Running EngineCore step.") - # # Append tokens to requests directly - # # to mimic ngram proposal. - # # Only change requests in the running queue. - # # We don't do spec decode in the prefill phase for now. - # # We don't handle prefill kv cache for now. - # for req in self.scheduler.running: - # req.append_spec_token_ids([1] * 5) + # Append tokens to requests directly + # to mimic ngram proposal. + # Only change requests in the running queue. + # We don't do spec decode in the prefill phase for now. + # We don't handle prefill kv cache for now. 
+ for req in self.scheduler.running: + req.append_spec_token_ids([1] * 5) scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6c736ac6783e..8e551e88ef76 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -317,20 +317,20 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # Get spec decode logits indices. spec_query_end_loc = 0 spec_decode_logits_indices = [] - # for i, req_id in enumerate(self.input_batch.req_ids): - # if i == num_reqs: - # break - # if req_id not in scheduler_output.scheduled_spec_decode_tokens: - # continue - # num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] - # num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] - # spec_query_end_loc += num_scheduled_tokens - # spec_token_ids = scheduler_output.scheduled_spec_decode_tokens[req_id] - # for j, spec_token_id in enumerate(spec_token_ids): - # # +1 here because the input for verification is [last_output_token_id] + spec_token_ids - # self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + j] = spec_token_id - # # -1 here because the input for verification is [last_output_token_id] + spec_token_ids - # spec_decode_logits_indices.extend(range(spec_query_end_loc - len(spec_token_ids) - 1, spec_query_end_loc)) + for i, req_id in enumerate(self.input_batch.req_ids): + if i == num_reqs: + break + if req_id not in scheduler_output.scheduled_spec_decode_tokens: + continue + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] + spec_query_end_loc += num_scheduled_tokens + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens[req_id] + for j, spec_token_id in enumerate(spec_token_ids): + # +1 here because the input for verification is [last_output_token_id] + spec_token_ids + self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + j] = spec_token_id + # -1 here because the input for verification is [last_output_token_id] + spec_token_ids + spec_decode_logits_indices.extend(range(spec_query_end_loc - len(spec_token_ids) - 1, spec_query_end_loc)) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large @@ -659,6 +659,7 @@ def execute_model( ) sampled_token_ids = sampler_output.sampled_token_ids + spec_tokens = scheduler_output.scheduled_spec_decode_tokens # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs @@ -670,11 +671,13 @@ def execute_model( seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] # assert seq_len <= req_state.num_tokens if seq_len >= req_state.num_tokens: + print("output_token_ids", sampled_token_ids[i], req_state.num_computed_tokens) # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. 
token_ids = sampled_token_ids[i] + spec_token_ids = spec_tokens.get(req_id, []) for j, token_id in enumerate(token_ids): - self.input_batch.token_ids_cpu[i, req_state.num_computed_tokens + 1 + j] = token_id + self.input_batch.token_ids_cpu[i, seq_len - len(spec_token_ids) + j] = token_id self.input_batch.num_tokens[i] += len(token_ids) req_state.output_token_ids.extend(token_ids) else: From 03cd3ddedb450ac3e4ef41af4e534bad7d29aec9 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 19 Jan 2025 14:46:18 -0800 Subject: [PATCH 05/75] pass args and minor variable name bug fix --- vllm/v1/core/scheduler.py | 5 ++++- vllm/v1/engine/core.py | 20 +++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 10 ++++------ 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 3221aa7a83fc..9a8bb0676b81 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -3,7 +3,8 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, Tuple, Union) -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import(CacheConfig, LoRAConfig, + SchedulerConfig, SpeculativeConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange @@ -28,10 +29,12 @@ def __init__( scheduler_config: SchedulerConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig] = None, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config + self.speculative_config = speculative_config # TODO: Support LoRA. assert lora_config is None, "V1 does not support LoRA yet." diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d4daa46e1e15..fb481eb38859 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -60,7 +60,8 @@ def __init__( # Setup scheduler. self.scheduler = Scheduler(vllm_config.scheduler_config, vllm_config.cache_config, - vllm_config.lora_config) + vllm_config.lora_config, + vllm_config.speculative_config) self._last_logging_time = time.time() @@ -120,14 +121,15 @@ def step(self) -> List[EngineCoreOutput]: if not self.scheduler.has_unfinished_requests(): return [] - logger.info("Running EngineCore step.") - # Append tokens to requests directly - # to mimic ngram proposal. - # Only change requests in the running queue. - # We don't do spec decode in the prefill phase for now. - # We don't handle prefill kv cache for now. - for req in self.scheduler.running: - req.append_spec_token_ids([1] * 5) + + if self.scheduler.speculative_config and self.scheduler.speculative_config.num_speculative_tokens > 0: + # Append tokens to requests directly + # to mimic ngram proposal. + # Only change requests in the running queue. + # We don't do spec decode in the prefill phase for now. + # We don't handle spec decode kv cache for now. 
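# The loop below appends constant tokens ([1] * num_speculative_tokens) only
# as a stand-in for a real draft proposer. For reference, a minimal ngram
# prompt-lookup proposal (the technique this placeholder mimics) could look
# like the following illustrative sketch; it is not the vLLM implementation.
from typing import List, Optional

def ngram_propose(token_ids: List[int], n: int, k: int) -> Optional[List[int]]:
    # If the last n tokens already appeared earlier in the sequence, propose
    # up to k tokens that followed that earlier occurrence.
    if len(token_ids) < n:
        return None
    pattern = token_ids[-n:]
    for start in range(len(token_ids) - n):
        if token_ids[start:start + n] == pattern:
            follow = token_ids[start + n:start + n + k]
            return follow or None
    return None

assert ngram_propose([1, 2, 3, 4, 5, 1, 2, 3], n=3, k=2) == [4, 5]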
+ for req in self.scheduler.running: + req.append_spec_token_ids([1] * self.scheduler.speculative_config.num_speculative_tokens) scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8e551e88ef76..15c4c0f0d7c0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -320,12 +320,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ for i, req_id in enumerate(self.input_batch.req_ids): if i == num_reqs: break - if req_id not in scheduler_output.scheduled_spec_decode_tokens: - continue - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + req_num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] - spec_query_end_loc += num_scheduled_tokens - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens[req_id] + spec_query_end_loc += req_num_scheduled_tokens + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) for j, spec_token_id in enumerate(spec_token_ids): # +1 here because the input for verification is [last_output_token_id] + spec_token_ids self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + j] = spec_token_id @@ -671,7 +669,7 @@ def execute_model( seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] # assert seq_len <= req_state.num_tokens if seq_len >= req_state.num_tokens: - print("output_token_ids", sampled_token_ids[i], req_state.num_computed_tokens) + print(req_state.req_id, "output_token_ids", sampled_token_ids[i], req_state.num_computed_tokens) # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. 
token_ids = sampled_token_ids[i] From 26ba690b8fc55cd4d92b2b44594e87bca3048a84 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 19 Jan 2025 14:51:25 -0800 Subject: [PATCH 06/75] minor --- vllm/platforms/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 3c5350b77834..27d168c5a9a3 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -124,7 +124,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "vllm.worker.multi_step_worker.MultiStepWorker" elif vllm_config.speculative_config: if envs.VLLM_USE_V1: - raise NotImplementedError + pass + # raise NotImplementedError else: parallel_config.worker_cls = \ "vllm.spec_decode.spec_decode_worker.create_spec_worker" From ca5e0dd01aaaf3578a3657fe1853ba9bf1fe83e4 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 19 Jan 2025 15:09:22 -0800 Subject: [PATCH 07/75] minimal example --- examples/offline_inference_spec_decode.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 examples/offline_inference_spec_decode.py diff --git a/examples/offline_inference_spec_decode.py b/examples/offline_inference_spec_decode.py new file mode 100644 index 000000000000..9b5237344613 --- /dev/null +++ b/examples/offline_inference_spec_decode.py @@ -0,0 +1,25 @@ +from vllm import LLM, SamplingParams + +prompts = [ + "The future of AI is", + "Speculative decoding is a method", +] +sampling_params = SamplingParams(temperature=0.0) + +llm = LLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + # tensor_parallel_size=1, + # speculative_model="/data/lily/eagle-8b-instruct-model", + # speculative_draft_tensor_parallel_size=1, + speculative_model='[ngram]', + ngram_prompt_lookup_max=5, + ngram_prompt_lookup_min=3, + num_speculative_tokens=3 +) + +outputs = llm.generate(prompts, sampling_params) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file From 62012d14f99a518d74401374fe4913ee58ee1a8a Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 19 Jan 2025 19:57:16 -0800 Subject: [PATCH 08/75] minor --- vllm/platforms/cuda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 79f816730c3f..c345ae51e76b 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -126,8 +126,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "vllm.worker.multi_step_worker.MultiStepWorker" elif vllm_config.speculative_config: if envs.VLLM_USE_V1: - pass - # raise NotImplementedError + parallel_config.worker_cls = \ + "vllm.v1.worker.gpu_worker.Worker" else: parallel_config.worker_cls = \ "vllm.spec_decode.spec_decode_worker.create_spec_worker" From 7bd3f275d3e765aa6656e24afd7aec51545d4cb5 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 19 Jan 2025 22:28:57 -0800 Subject: [PATCH 09/75] format --- vllm/v1/core/kv_cache_manager.py | 5 ++- vllm/v1/core/scheduler.py | 24 +++++++----- vllm/v1/engine/core.py | 23 +++++++----- vllm/v1/request.py | 12 +++--- vllm/v1/sample/rejection_sampler.py | 31 ++++++++-------- vllm/v1/sample/sampler.py | 4 +- vllm/v1/worker/gpu_input_batch.py | 1 + vllm/v1/worker/gpu_model_runner.py | 57 +++++++++++++++++------------ 8 files changed, 89 insertions(+), 68 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 
77ecffa234a4..63210bd36e08 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -169,11 +169,12 @@ def append_slots( # TODO(rickyx): When supporting speculative decoding, we will need to # differentiate between them so that we can know how many blocks are # full after appending the actual tokens. - + # Does not include speculative tokens. # FIXME: The logic is not correct because # we never count speculative tokens that are accepted. - num_cached_tokens = request.num_computed_tokens + num_tokens - len(request.spec_token_ids) + num_cached_tokens = request.num_computed_tokens + num_tokens - len( + request.spec_token_ids) num_full_blocks_after_append = num_cached_tokens // self.block_size assert num_full_blocks_after_append <= len(req_blocks) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 502a14f1f3bc..f2626111dfa2 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -3,8 +3,8 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, Tuple, Union) -from vllm.config import(CacheConfig, LoRAConfig, ModelConfig, - SchedulerConfig, SpeculativeConfig) +from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, + SpeculativeConfig) from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -97,7 +97,8 @@ def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. # Each request just has the num_computed_tokens and num_tokens. - # num_tokens = len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids) + # num_tokens = len(prompt_token_ids) + len(output_token_ids) + + # len(spec_token_ids). # At each step, the scheduler tries to assign tokens to the requests # so that each request's num_computed_tokens can catch up its # num_tokens. This is general enough to cover chunked prefills, @@ -131,7 +132,8 @@ def schedule(self) -> "SchedulerOutput": assert not has_partial_request assert token_budget > 0 request = self.running[req_index] - num_new_tokens = request.num_tokens_with_spec - request.num_computed_tokens + num_new_tokens = request.num_tokens_with_spec \ + - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -187,10 +189,11 @@ def schedule(self) -> "SchedulerOutput": for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget - + if request.spec_token_ids: spec_decode = True - scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids + scheduled_spec_decode_tokens[ + request.request_id] = request.spec_token_ids # Next, schedule the WAITING requests. if not preempted_reqs: @@ -421,7 +424,7 @@ def update_from_output( # NOTE(woosuk): This method doesn't consider speculative decoding. 
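# Illustrative check (not part of the patch) of the num_computed_tokens update
# marked FIXME a few lines below: the increment is the scheduled tokens minus
# the rejected draft tokens, so only verified positions count as computed.
num_spec = 2                               # draft tokens scheduled
num_scheduled = 1 + num_spec               # last output token + drafts
num_sampled = 2                            # 1 accepted draft + 1 bonus token
num_rejected = num_spec + 1 - num_sampled  # = 1
assert num_scheduled - num_rejected == 2   # KV cache advances by 2 positions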
sampled_token_ids = model_runner_output.sampled_token_ids num_scheduled_tokens = scheduler_output.num_scheduled_tokens - + new_running: List[Request] = [] outputs: List[EngineCoreOutput] = [] for request in self.running: @@ -429,8 +432,9 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] token_ids = sampled_token_ids[req_index] # FIXME: have a cleaner way to handle this - request.num_computed_tokens += num_scheduled_tokens[req_id] - (len(request.spec_token_ids) + 1 - len(token_ids)) - + request.num_computed_tokens += num_scheduled_tokens[req_id] - ( + len(request.spec_token_ids) + 1 - len(token_ids)) + # When the request's num_computed_tokens catches up its num_tokens, # the request generates output tokens. Otherwise, we ignore the # sampler output for the request. @@ -477,7 +481,7 @@ def update_from_output( ) # TODO: the following logic does not consider - # when multiple tokens are generated in a + # when multiple tokens are generated in a # single forward pass def _check_stop(self, request: Request) -> bool: if (request.num_tokens >= self.max_model_len diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 83e75b60eb99..b598bb6d9a9c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -55,11 +55,12 @@ def __init__( vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks # Setup scheduler. - self.scheduler = Scheduler(scheduler_config=vllm_config.scheduler_config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - lora_config=vllm_config.lora_config, - speculative_config=vllm_config.speculative_config) + self.scheduler = Scheduler( + scheduler_config=vllm_config.scheduler_config, + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + lora_config=vllm_config.lora_config, + speculative_config=vllm_config.speculative_config) self._last_logging_time = time.time() @@ -124,16 +125,18 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - - if self.scheduler.speculative_config and self.scheduler.speculative_config.num_speculative_tokens > 0: + if self.scheduler.speculative_config and \ + self.scheduler.speculative_config.num_speculative_tokens > 0: # Append tokens to requests directly - # to mimic ngram proposal. + # to mimic ngram proposal. # Only change requests in the running queue. # We don't do spec decode in the prefill phase for now. # We don't handle spec decode kv cache for now. 
for req in self.scheduler.running: - req.append_spec_token_ids([1] * self.scheduler.speculative_config.num_speculative_tokens) - + req.append_spec_token_ids( + [1] * + self.scheduler.speculative_config.num_speculative_tokens) + scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 0df43398cce2..b38d43decb1a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -100,7 +100,7 @@ def append_output_token_ids( token_ids = [token_ids] self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) - + def append_spec_token_ids( self, token_ids: Union[int, List[int]], @@ -108,18 +108,18 @@ def append_spec_token_ids( if isinstance(token_ids, int): token_ids = [token_ids] self._spec_token_ids.extend(token_ids) - + def clear_spec_tokens(self) -> None: self._spec_token_ids = [] - + @property - def spec_token_ids(self) -> ConstantList[int]: - return ConstantList(self._spec_token_ids) + def spec_token_ids(self) -> List[int]: + return self._spec_token_ids @property def num_tokens(self) -> int: return len(self._all_token_ids) - + @property def num_tokens_with_spec(self) -> int: return len(self._all_token_ids) + len(self._spec_token_ids) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 7cb807173db1..854317077641 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,19 +1,22 @@ import torch import torch.nn as nn -from vllm.v1.outputs import SamplerOutput + from vllm.logger import init_logger +from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata logger = init_logger(__name__) + + class RejectionSampler(nn.Module): - def sample(self, - logits: torch.Tensor, + + def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: # num_reqs x [num_specuated_tokens] spec_token_ids = sampling_metadata.spec_token_ids # only argmax is supported for now output_token_ids_cpu = logits.argmax(dim=-1).view(-1).tolist() - + sampled_token_ids = [] # Stop at the first mismatch place. # spec_tokens: [1, 2, 3] @@ -22,7 +25,9 @@ def sample(self, output_token_start_idx = 0 for spec_tokens in spec_token_ids: num_spec_tokens = len(spec_tokens) - output_tokens = output_token_ids_cpu[output_token_start_idx: output_token_start_idx + 1 + num_spec_tokens] + output_tokens = output_token_ids_cpu[ + output_token_start_idx:output_token_start_idx + 1 + + num_spec_tokens] i = 0 while i < len(spec_tokens): if spec_tokens[i] != output_tokens[i]: @@ -30,17 +35,13 @@ def sample(self, i += 1 # +1 to include the bonus token. 
i += 1 - output_tokens = output_tokens[:i] + output_tokens = output_tokens[:i] sampled_token_ids.append(output_tokens) output_token_start_idx += num_spec_tokens + 1 - print("Proposed token ids", spec_tokens._x) - print("Output token ids", output_tokens) assert output_token_start_idx == len(output_token_ids_cpu) - + return SamplerOutput(sampled_token_ids=sampled_token_ids, - logprob_token_ids=None, - logprobs=None, - prompt_logprob_token_ids=None, - prompt_logprobs=None) - - \ No newline at end of file + logprob_token_ids=None, + logprobs=None, + prompt_logprob_token_ids=None, + prompt_logprobs=None) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 9cb913aa4425..4ece163d1f44 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -28,10 +28,10 @@ def forward( ) -> SamplerOutput: if sampling_metadata.rejection_sampling: return self.rejection_sampler.sample( - logits, + logits, sampling_metadata, ) - + needs_logprobs = sampling_metadata.max_num_logprobs > 0 if needs_logprobs: # NOTE(woosuk): Use the original logits (before any penalties or diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1902d8eddc82..b86dd48ae3ac 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -364,6 +364,7 @@ def make_sampling_metadata( # statistics. output_token_ids.append(req_id_output_token_ids[req_id]) if rejection_sampling: + assert req_id_to_spec_token_ids is not None spec_token_ids.append(req_id_to_spec_token_ids[req_id]) return SamplingMetadata( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4bead0742b61..7c86247698c5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional +from typing import TYPE_CHECKING, Dict, List, Tuple, cast import numpy as np import torch @@ -383,24 +383,32 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - + # Add spec decode tokens to input_batch.token_ids_cpu. # Get spec decode logits indices. 
spec_query_end_loc = 0 - spec_decode_logits_indices = [] + spec_decode_logits_indices: List[int] = [] for i, req_id in enumerate(self.input_batch.req_ids): if i == num_reqs: break - req_num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + assert req_id is not None + req_num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] spec_query_end_loc += req_num_scheduled_tokens - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, []) for j, spec_token_id in enumerate(spec_token_ids): - # +1 here because the input for verification is [last_output_token_id] + spec_token_ids - self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + j] = spec_token_id - # -1 here because the input for verification is [last_output_token_id] + spec_token_ids - spec_decode_logits_indices.extend(range(spec_query_end_loc - len(spec_token_ids) - 1, spec_query_end_loc)) - + # +1 here because the input for verification is + # [last_output_token_id] + spec_token_ids + self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + + j] = spec_token_id + # -1 here because the input for verification is + # [last_output_token_id] + spec_token_ids + spec_decode_logits_indices.extend( + range(spec_query_end_loc - len(spec_token_ids) - 1, + spec_query_end_loc)) + # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. @@ -555,16 +563,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ cu_prefix_kv_lens=cu_prefix_kv_lens, cu_suffix_kv_lens=cu_suffix_kv_lens, ) - + if scheduler_output.use_spec_decode: - logits_indices = torch.tensor(spec_decode_logits_indices, device=self.device) + logits_indices = torch.tensor(spec_decode_logits_indices, + device=self.device) else: - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 + # partial request in the batch. While we should not sample any + # token from this partial request, we do so for simplicity. + # We will ignore the sampled token from the partial request. # TODO: Support prompt logprobs. 
- logits_indices = query_start_loc[1:] - 1 + logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): @@ -639,8 +648,7 @@ def _prepare_sampling( for req_id, req in self.requests.items()} sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, - skip_copy, + req_id_output_token_ids, skip_copy, scheduler_output.use_spec_decode, scheduler_output.scheduled_spec_decode_tokens) return sampling_metadata @@ -803,16 +811,19 @@ def execute_model( break assert req_id is not None req_state = self.requests[req_id] - seq_len = req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] - # assert seq_len <= req_state.num_tokens + seq_len = req_state.num_computed_tokens + \ + scheduler_output.num_scheduled_tokens[req_id] if seq_len >= req_state.num_tokens: - print(req_state.req_id, "output_token_ids", sampled_token_ids[i], req_state.num_computed_tokens) + print(req_state.req_id, "output_token_ids", + sampled_token_ids[i], req_state.num_computed_tokens) # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. token_ids = sampled_token_ids[i] spec_token_ids = spec_tokens.get(req_id, []) for j, token_id in enumerate(token_ids): - self.input_batch.token_ids_cpu[i, seq_len - len(spec_token_ids) + j] = token_id + self.input_batch.token_ids_cpu[i, seq_len - + len(spec_token_ids) + + j] = token_id self.input_batch.num_tokens[i] += len(token_ids) req_state.output_token_ids.extend(token_ids) else: From b0c5d25c59432a32b257dac3e08fa6bae794ba5c Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 20 Jan 2025 12:07:23 -0800 Subject: [PATCH 10/75] basic test --- examples/offline_inference_spec_decode.py | 25 --------------- tests/v1/e2e/test_basic_specdecode.py | 39 +++++++++++++++++++++++ 2 files changed, 39 insertions(+), 25 deletions(-) delete mode 100644 examples/offline_inference_spec_decode.py create mode 100644 tests/v1/e2e/test_basic_specdecode.py diff --git a/examples/offline_inference_spec_decode.py b/examples/offline_inference_spec_decode.py deleted file mode 100644 index 9b5237344613..000000000000 --- a/examples/offline_inference_spec_decode.py +++ /dev/null @@ -1,25 +0,0 @@ -from vllm import LLM, SamplingParams - -prompts = [ - "The future of AI is", - "Speculative decoding is a method", -] -sampling_params = SamplingParams(temperature=0.0) - -llm = LLM( - model="meta-llama/Meta-Llama-3-8B-Instruct", - # tensor_parallel_size=1, - # speculative_model="/data/lily/eagle-8b-instruct-model", - # speculative_draft_tensor_parallel_size=1, - speculative_model='[ngram]', - ngram_prompt_lookup_max=5, - ngram_prompt_lookup_min=3, - num_speculative_tokens=3 -) - -outputs = llm.generate(prompts, sampling_params) - -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/tests/v1/e2e/test_basic_specdecode.py b/tests/v1/e2e/test_basic_specdecode.py new file mode 100644 index 000000000000..1852f9242754 --- /dev/null +++ b/tests/v1/e2e/test_basic_specdecode.py @@ -0,0 +1,39 @@ +from vllm import LLM, SamplingParams + +prompts = [ + "The future of AI is", + "This is a basic spec decode test", +] +# Only support greedy for now +sampling_params = SamplingParams(temperature=0) + + + +def test_basic_specdecode(monkeypatch): + ''' + Compare the outputs of a original LLM and a speculative 
LLM + should be the same. + ''' + prompts = [ "The future of AI is", + "This is a basic spec decode test",] + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + model = "meta-llama/Meta-Llama-3-8B-Instruct" + + ref_llm = LLM(model=model) + ref_outputs = ref_llm.generate(prompts, sampling_params) + del ref_llm + # print(ref_outputs.outputs[0].text) + + spec_llm = LLM( + model=model, + speculative_model='[ngram]', + ngram_prompt_lookup_max=5, + ngram_prompt_lookup_min=3, + num_speculative_tokens=3 + ) + spec_outputs = spec_llm.generate(prompts, sampling_params) + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + assert ref_output.outputs[0].text == spec_output.outputs[0].text, \ + f"ref_output: {ref_output.outputs[0].text}, spec_output: {spec_output.outputs[0].text}" + del spec_llm \ No newline at end of file From d5ee0812820c0c1593d887999418a06c6223acf1 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 20 Jan 2025 12:09:26 -0800 Subject: [PATCH 11/75] minor --- tests/v1/e2e/test_basic_specdecode.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/v1/e2e/test_basic_specdecode.py b/tests/v1/e2e/test_basic_specdecode.py index 1852f9242754..012dfdb27def 100644 --- a/tests/v1/e2e/test_basic_specdecode.py +++ b/tests/v1/e2e/test_basic_specdecode.py @@ -8,32 +8,27 @@ sampling_params = SamplingParams(temperature=0) - def test_basic_specdecode(monkeypatch): ''' Compare the outputs of a original LLM and a speculative LLM should be the same. ''' - prompts = [ "The future of AI is", - "This is a basic spec decode test",] with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") model = "meta-llama/Meta-Llama-3-8B-Instruct" - + ref_llm = LLM(model=model) ref_outputs = ref_llm.generate(prompts, sampling_params) del ref_llm # print(ref_outputs.outputs[0].text) - spec_llm = LLM( - model=model, - speculative_model='[ngram]', - ngram_prompt_lookup_max=5, - ngram_prompt_lookup_min=3, - num_speculative_tokens=3 - ) + spec_llm = LLM(model=model, + speculative_model='[ngram]', + ngram_prompt_lookup_max=5, + ngram_prompt_lookup_min=3, + num_speculative_tokens=3) spec_outputs = spec_llm.generate(prompts, sampling_params) for ref_output, spec_output in zip(ref_outputs, spec_outputs): assert ref_output.outputs[0].text == spec_output.outputs[0].text, \ f"ref_output: {ref_output.outputs[0].text}, spec_output: {spec_output.outputs[0].text}" - del spec_llm \ No newline at end of file + del spec_llm From 4e11585e4de66879fb195152e7a843987a01aca4 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 20 Jan 2025 12:12:17 -0800 Subject: [PATCH 12/75] minor --- tests/v1/e2e/test_basic_specdecode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_basic_specdecode.py b/tests/v1/e2e/test_basic_specdecode.py index 012dfdb27def..1aecc6928e08 100644 --- a/tests/v1/e2e/test_basic_specdecode.py +++ b/tests/v1/e2e/test_basic_specdecode.py @@ -30,5 +30,6 @@ def test_basic_specdecode(monkeypatch): spec_outputs = spec_llm.generate(prompts, sampling_params) for ref_output, spec_output in zip(ref_outputs, spec_outputs): assert ref_output.outputs[0].text == spec_output.outputs[0].text, \ - f"ref_output: {ref_output.outputs[0].text}, spec_output: {spec_output.outputs[0].text}" + (f"ref_output: {ref_output.outputs[0].text}," + f"spec_output: {spec_output.outputs[0].text}") del spec_llm From f915eda7e08d1ee2a29a3e72d0f9200a077257b1 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 20 Jan 2025 14:06:49 -0800 
Subject: [PATCH 13/75] stop checking --- vllm/v1/core/scheduler.py | 45 ++++++++++++++++++++++++++++++--------- vllm/v1/request.py | 7 ++++++ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f2626111dfa2..b11132d12f96 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -480,29 +480,54 @@ def update_from_output( scheduler_stats=self.make_stats(), ) - # TODO: the following logic does not consider - # when multiple tokens are generated in a - # single forward pass + def _crop_request(self, request: Request, num_total_token: int) -> None: + """Truncate the request to the num_total_token. + We do not need to update the input batch because + it will be updated in the next execute_model call's + _update_states method, where the request data is aligned + with the data in the persistent batch. + """ + request.crop(num_total_token) + def _check_stop(self, request: Request) -> bool: + """ + Check if the request should be stopped. + The function should handle both single token generation or + multiple token generation (e.g., spec decode) per step. + """ if (request.num_tokens >= self.max_model_len or request.num_output_tokens >= request.max_tokens): request.status = RequestStatus.FINISHED_LENGTH_CAPPED + num_total_token = min( + self.max_model_len, request.num_tokens, + request.max_tokens + request.num_prompt_tokens, + request.num_output_tokens + request.num_output_tokens) + self._crop_request(request, num_total_token) self._free_request(request) return True sampling_params = request.sampling_params - last_token_id = request.output_token_ids[-1] if (not sampling_params.ignore_eos - and last_token_id == request.eos_token_id): + and request.eos_token_id in request.output_token_ids): + assert request.eos_token_id is not None request.status = RequestStatus.FINISHED_STOPPED + num_total_token = request.num_prompt_tokens + \ + request.output_token_ids.index(request.eos_token_id) + 1 + self._crop_request(request, num_total_token) self._free_request(request) return True - if last_token_id in (sampling_params.stop_token_ids or ()): - request.status = RequestStatus.FINISHED_STOPPED - request.stop_reason = last_token_id - self._free_request(request) - return True + stop_token_ids = set(sampling_params.stop_token_ids or set([])) + output_token_ids = set(request.output_token_ids) + for stop_token_id in stop_token_ids: + if stop_token_id in output_token_ids: + request.status = RequestStatus.FINISHED_STOPPED + request.stop_reason = stop_token_id + num_total_token = request.num_prompt_tokens + \ + request.output_token_ids.index(stop_token_id) + 1 + self._crop_request(request, num_total_token) + self._free_request(request) + return True return False def add_request(self, request: Request) -> None: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b38d43decb1a..1bea0b1cefa4 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -157,6 +157,13 @@ def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: self._kv_block_hashes.append(block_hash) + def crop(self, num_total_token: int) -> None: + if num_total_token < self.num_prompt_tokens: + raise ValueError("Cannot crop the prompt tokens.") + num_output_token = num_total_token - self.num_prompt_tokens + self._output_token_ids = self._output_token_ids[:num_output_token] + self._all_token_ids = self._all_token_ids[:num_total_token] + class RequestStatus(enum.IntEnum): """Status of a 
request.""" From bd8ac07f38f751c8a07ad2cceae7232a59ad7e0c Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 20 Jan 2025 20:54:50 -0800 Subject: [PATCH 14/75] test for stop checking --- tests/v1/core/test_stop_checking.py | 95 +++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/v1/core/test_stop_checking.py diff --git a/tests/v1/core/test_stop_checking.py b/tests/v1/core/test_stop_checking.py new file mode 100644 index 000000000000..8a74caa8c308 --- /dev/null +++ b/tests/v1/core/test_stop_checking.py @@ -0,0 +1,95 @@ +import pytest + +from typing import List +from vllm.config import SchedulerConfig, ModelConfig, CacheConfig +from vllm.v1.core.scheduler import Scheduler +from vllm.v1.request import Request, RequestStatus, SamplingParams + + +EOS_TOKEN_ID = 50256 + +@pytest.fixture +def scheduler(): + cache_config=CacheConfig(block_size=16, + gpu_memory_utilization=0.9, + swap_space=0.1, + cache_dtype="auto") + cache_config.num_gpu_blocks = 100 + return Scheduler( + scheduler_config=SchedulerConfig(), + model_config=ModelConfig(model="facebook/opt-125m", + task="auto", + tokenizer="test_tokenizer", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=42), + cache_config=cache_config, + lora_config=None) + + +def _create_test_request(request_id: str, + max_tokens: int, + stop_token_ids: List[int]) -> Request: + return Request( + request_id=request_id, + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + multi_modal_inputs=None, + multi_modal_hashes=None, + multi_modal_placeholders=None, + sampling_params=SamplingParams(max_tokens=max_tokens, + stop_token_ids=stop_token_ids), + eos_token_id=EOS_TOKEN_ID, + arrival_time=0.0 + ) + +def test_multiple_stop_tokens(scheduler): + """Test with stop when generating multiple tokens""" + # Nonstop case + request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([4, 5, 6, 7, 8]) + result = scheduler._check_stop(request) + assert result is False + + # EOS token is generated in the beginnig of the output tokens + request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([EOS_TOKEN_ID, 5, EOS_TOKEN_ID, 7, 43, 5]) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_STOPPED + assert request.request_id in scheduler.finished_req_ids + # Should be cropped at the first stop token + assert len(request.output_token_ids) == 1 + assert list(request.output_token_ids) == [EOS_TOKEN_ID] + + # Stop token, 43 is one of the stop tokens + request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([4, 5, 43, 7, 43, 5]) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_STOPPED + assert request.stop_reason == 43 + assert request.request_id in scheduler.finished_req_ids + # Should be cropped at the first stop token + assert len(request.output_token_ids) == 3 + assert list(request.output_token_ids) == [4, 5, 43] + + # Max tokens, should be cropped when reaching the max tokens + max_tokens = 2 + request = _create_test_request("test2", max_tokens, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + output_token_ids = [4, 5, 43, 7, 43, 5] + 
request.append_output_token_ids(output_token_ids) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED + assert request.request_id in scheduler.finished_req_ids + # Should be cropped at the first stop token + assert len(request.output_token_ids) == max_tokens + assert list(request.output_token_ids) == output_token_ids[:max_tokens] + + From 008a41e3cedca1f0c1d9240872e03a9dce73b019 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 21 Jan 2025 18:37:20 -0800 Subject: [PATCH 15/75] style and disable scheduling chunked requests --- tests/v1/core/test_stop_checking.py | 78 ++++++++++++++--------------- vllm/v1/engine/core.py | 26 ++++++---- 2 files changed, 53 insertions(+), 51 deletions(-) diff --git a/tests/v1/core/test_stop_checking.py b/tests/v1/core/test_stop_checking.py index 8a74caa8c308..4913d9b9ab6c 100644 --- a/tests/v1/core/test_stop_checking.py +++ b/tests/v1/core/test_stop_checking.py @@ -1,49 +1,47 @@ +from typing import List + import pytest -from typing import List -from vllm.config import SchedulerConfig, ModelConfig, CacheConfig +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.v1.core.scheduler import Scheduler from vllm.v1.request import Request, RequestStatus, SamplingParams - EOS_TOKEN_ID = 50256 + @pytest.fixture def scheduler(): - cache_config=CacheConfig(block_size=16, - gpu_memory_utilization=0.9, - swap_space=0.1, - cache_dtype="auto") + cache_config = CacheConfig(block_size=16, + gpu_memory_utilization=0.9, + swap_space=0.1, + cache_dtype="auto") cache_config.num_gpu_blocks = 100 - return Scheduler( - scheduler_config=SchedulerConfig(), - model_config=ModelConfig(model="facebook/opt-125m", - task="auto", - tokenizer="test_tokenizer", - tokenizer_mode="auto", - trust_remote_code=False, - dtype="float16", - seed=42), - cache_config=cache_config, - lora_config=None) + return Scheduler(scheduler_config=SchedulerConfig(), + model_config=ModelConfig(model="facebook/opt-125m", + task="auto", + tokenizer="test_tokenizer", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=42), + cache_config=cache_config, + lora_config=None) -def _create_test_request(request_id: str, - max_tokens: int, +def _create_test_request(request_id: str, max_tokens: int, stop_token_ids: List[int]) -> Request: - return Request( - request_id=request_id, - prompt="test prompt", - prompt_token_ids=[1, 2, 3], - multi_modal_inputs=None, - multi_modal_hashes=None, - multi_modal_placeholders=None, - sampling_params=SamplingParams(max_tokens=max_tokens, - stop_token_ids=stop_token_ids), - eos_token_id=EOS_TOKEN_ID, - arrival_time=0.0 - ) - + return Request(request_id=request_id, + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + multi_modal_inputs=None, + multi_modal_hashes=None, + multi_modal_placeholders=None, + sampling_params=SamplingParams( + max_tokens=max_tokens, stop_token_ids=stop_token_ids), + eos_token_id=EOS_TOKEN_ID, + arrival_time=0.0) + + def test_multiple_stop_tokens(scheduler): """Test with stop when generating multiple tokens""" # Nonstop case @@ -52,8 +50,8 @@ def test_multiple_stop_tokens(scheduler): request.append_output_token_ids([4, 5, 6, 7, 8]) result = scheduler._check_stop(request) assert result is False - - # EOS token is generated in the beginnig of the output tokens + + # EOS token is generated in the beginning of the output tokens request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] 
= request request.append_output_token_ids([EOS_TOKEN_ID, 5, EOS_TOKEN_ID, 7, 43, 5]) @@ -64,7 +62,7 @@ def test_multiple_stop_tokens(scheduler): # Should be cropped at the first stop token assert len(request.output_token_ids) == 1 assert list(request.output_token_ids) == [EOS_TOKEN_ID] - + # Stop token, 43 is one of the stop tokens request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request @@ -77,10 +75,12 @@ def test_multiple_stop_tokens(scheduler): # Should be cropped at the first stop token assert len(request.output_token_ids) == 3 assert list(request.output_token_ids) == [4, 5, 43] - + # Max tokens, should be cropped when reaching the max tokens max_tokens = 2 - request = _create_test_request("test2", max_tokens, stop_token_ids=[42, 43, 44]) + request = _create_test_request("test2", + max_tokens, + stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request output_token_ids = [4, 5, 43, 7, 43, 5] request.append_output_token_ids(output_token_ids) @@ -91,5 +91,3 @@ def test_multiple_stop_tokens(scheduler): # Should be cropped at the first stop token assert len(request.output_token_ids) == max_tokens assert list(request.output_token_ids) == output_token_ids[:max_tokens] - - diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b598bb6d9a9c..556959532f42 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -118,6 +118,19 @@ def abort_requests(self, request_ids: List[str]): self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) + def propose_tokens(self): + assert self.scheduler.speculative_config is not None + # Append tokens to requests directly + # to mimic ngram proposal. + # Only change requests in the decoding phase. + # We don't handle spec decode kv cache for now. + for req in self.scheduler.running: + # Ignore requests that are doing chunked prefill. + if req.num_computed_tokens >= req.num_tokens: + req.append_spec_token_ids( + [1] * + self.scheduler.speculative_config.num_speculative_tokens) + def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" @@ -125,17 +138,8 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - if self.scheduler.speculative_config and \ - self.scheduler.speculative_config.num_speculative_tokens > 0: - # Append tokens to requests directly - # to mimic ngram proposal. - # Only change requests in the running queue. - # We don't do spec decode in the prefill phase for now. - # We don't handle spec decode kv cache for now. 
- for req in self.scheduler.running: - req.append_spec_token_ids( - [1] * - self.scheduler.speculative_config.num_speculative_tokens) + if self.scheduler.speculative_config: + self.propose_tokens() scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) From 784b24aaa94bf9e51f6e75a728cec06ac88b0f09 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 21 Jan 2025 18:49:59 -0800 Subject: [PATCH 16/75] signed-off-by Signed-off-by: LiuXiaoxuanPKU From f3f6ebc62aaccf78ac147decc16cb9c92f214e1a Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 23 Jan 2025 18:50:24 -0800 Subject: [PATCH 17/75] ngram proposer --- tests/v1/sample/test_rejection_sampler.py | 145 ++++++++++++++++++++++ tests/v1/spec_decode/test_ngram.py | 34 +++++ vllm/v1/core/kv_cache_manager.py | 13 +- vllm/v1/core/scheduler.py | 6 +- vllm/v1/engine/core.py | 29 +++-- vllm/v1/spec_decode/ngram_proposer.py | 82 ++++++++++++ 6 files changed, 287 insertions(+), 22 deletions(-) create mode 100644 tests/v1/sample/test_rejection_sampler.py create mode 100644 tests/v1/spec_decode/test_ngram.py create mode 100644 vllm/v1/spec_decode/ngram_proposer.py diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py new file mode 100644 index 000000000000..d882e333624b --- /dev/null +++ b/tests/v1/sample/test_rejection_sampler.py @@ -0,0 +1,145 @@ +import pytest +import torch +from typing import List +from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.sample.metadata import SamplingMetadata + +@pytest.fixture +def sampler(): + return RejectionSampler() + +def create_logits_tensor(token_ids: List[int], vocab_size: int = 100) -> torch.Tensor: + """Helper function to create logits tensor that will produce desired token ids on argmax""" + logits = torch.full((len(token_ids), vocab_size), -100.0) + for i, token_id in enumerate(token_ids): + logits[i, token_id] = 100.0 + return logits + +def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: + return SamplingMetadata( + temperature=0.0, + all_greedy=True, + all_random=False, + rejection_sampling=True, + spec_token_ids=spec_tokens, + top_p=None, + top_k=None, + no_top_p=False, + no_top_k=False, + generators={}, + max_num_logprobs=0, + no_penalties=False, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + min_tokens=[], + stop_token_ids=[] + ) + +def test_perfect_match(sampler): + """Test when output tokens perfectly match speculated tokens""" + spec_tokens = [[1, 2, 3]] + output_tokens = [1, 2, 3, 4] # 4 is the bonus token + + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == [[1, 2, 3, 4]] + +def test_early_mismatch(sampler): + """Test when there's an early mismatch in tokens""" + spec_tokens = [[1, 2, 3]] + output_tokens = [1, 5, 3, 4] # Mismatch at position 1 + + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == [[1, 5]] + +def test_multiple_sequences(sampler): + """Test handling multiple sequences of speculated tokens""" + spec_tokens = [[1, 2], [3, 4]] + output_tokens = [1, 2, 5, 3, 4, 6] # Two sequences with bonus tokens 5 and 6 + + metadata = create_sampling_metadata(spec_tokens) 
+ logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == [[1, 2, 5], [3, 4, 6]] + +def test_single_token_sequence(sampler): + """Test handling sequences with single token""" + spec_tokens = [[1]] + output_tokens = [1, 2] # Single token with bonus token 2 + + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == [[1, 2]] + +def test_empty_sequence(sampler): + """Test handling empty sequence of speculated tokens""" + spec_tokens: List[List[int]] = [[]] + output_tokens = [5] # Just the bonus token + + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == [[5]] + +def test_multiple_mismatches(sampler): + """Test handling multiple sequences with mismatches""" + spec_tokens = [[1, 2, 3], [4, 5, 6]] + output_tokens = [1, 2, 7, 6, 4, 8, 6, 9] # Mismatches in both sequences + + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == [[1, 2, 7], [4, 8]] + +@pytest.mark.parametrize("spec_tokens,output_tokens,expected", [ + ([[1, 2]], [1, 2, 3], [[1, 2, 3]]), # Perfect match with bonus + ([[1]], [2, 3], [[2]]), # First mismatch + ([[1, 2], [3, 4]], [1, 5, 6, 3, 4, 7], [[1, 5], [3, 4, 7]]), # Mixed matches +]) +def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): + """Parametrized test for various matching scenarios""" + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == expected + +def test_logits_shape_handling(sampler): + """Test handling of different logits tensor shapes""" + spec_tokens = [[1, 2]] + output_tokens = [1, 2, 3] + vocab_size = 1000 + + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens, vocab_size) + + output = sampler.sample(logits, metadata) + assert output.sampled_token_ids == [[1, 2, 3]] + assert logits.shape[-1] == vocab_size + +def test_none_outputs(sampler): + """Test that other output fields are None as expected""" + spec_tokens = [[1]] + output_tokens = [1, 2] + + metadata = create_sampling_metadata(spec_tokens) + logits = create_logits_tensor(output_tokens) + + output = sampler.sample(logits, metadata) + assert output.logprob_token_ids is None + assert output.logprobs is None + assert output.prompt_logprob_token_ids is None + assert output.prompt_logprobs is None \ No newline at end of file diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py new file mode 100644 index 000000000000..654b7ecc21d7 --- /dev/null +++ b/tests/v1/spec_decode/test_ngram.py @@ -0,0 +1,34 @@ +import pytest +from typing import List, Optional +from vllm.v1.utils import ConstantList +from vllm.v1.spec_decode.ngram_proposer import NgramProposer + +@pytest.fixture +def proposer(): + return NgramProposer() + +def test_kmp_lps_array(proposer): + assert proposer._kmp_lps_array([]) == [] + assert proposer._kmp_lps_array([1]) == [0] + assert proposer._kmp_lps_array([1, 1, 1]) == [0, 1, 2] + assert proposer._kmp_lps_array([1, 2, 3, 4]) == [0, 0, 0, 0] + assert proposer._kmp_lps_array([1, 2, 1, 2, 3]) == [0, 0, 1, 2, 0] + +def 
test_find_subarray_kmp(proposer): + X = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) + + assert proposer._find_subarray_kmp(X, [], 2) == [1, 2] + assert proposer._find_subarray_kmp(X, [7, 8], 1) is None + assert proposer._find_subarray_kmp(X, [1, 2, 3], 2) == [1, 2, 3, 4, 1] + assert proposer._find_subarray_kmp(X, [5, 6], 1) == [5, 6] + +def test_propose(proposer): + context = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) + assert proposer.propose(context, n=3, k=2) is None + assert proposer.propose(context, n=2, k=1) is None + assert proposer.propose(context, n=4, k=2) is None + + context = ConstantList([1, 2, 3, 4, 1, 2]) + assert proposer.propose(context, n=2, k=2) == [3, 4] + assert proposer.propose(context, n=2, k=1) == [3] + assert proposer.propose(context, n=3, k=2) is None \ No newline at end of file diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 63210bd36e08..3c169a02c3ec 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -164,15 +164,10 @@ def append_slots( num_computed_full_blocks = (request.num_computed_tokens // self.block_size) - # NOTE(rickyx): We are assuming the `num_tokens` are actual - # tokens rather than lookahead slots (e.g. for speculative decoding). - # TODO(rickyx): When supporting speculative decoding, we will need to - # differentiate between them so that we can know how many blocks are - # full after appending the actual tokens. - - # Does not include speculative tokens. - # FIXME: The logic is not correct because - # we never count speculative tokens that are accepted. + # When calculating new full blocks, we exclude speculative tokens. + # We only cache blocks where token_ids are valid. KV cache of + # speculative tokens will be valid once these tokens are accepted + # (tracked by num_computed_tokens). num_cached_tokens = request.num_computed_tokens + num_tokens - len( request.spec_token_ids) num_full_blocks_after_append = num_cached_tokens // self.block_size diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index b11132d12f96..092e79a0faa0 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -435,11 +435,6 @@ def update_from_output( request.num_computed_tokens += num_scheduled_tokens[req_id] - ( len(request.spec_token_ids) + 1 - len(token_ids)) - # When the request's num_computed_tokens catches up its num_tokens, - # the request generates output tokens. Otherwise, we ignore the - # sampler output for the request. - assert request.num_computed_tokens <= request.num_tokens - cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) for input_id in list(cached_encoder_input_ids): @@ -455,6 +450,7 @@ def update_from_output( request.append_output_token_ids(token_ids) num_new_tokens = len(token_ids) # TODO: Update the KV cache manager for prefix caching. + # self.kv_cache_manager.uncache_blocks(request) # Check for stop and update request state. # This must be called before me make the EngineCoreOutput. 
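The num_computed_tokens adjustment in the scheduler hunk above is the crux of how partially accepted speculative tokens are accounted for: the scheduled step covers the last output token plus the proposed speculative tokens, while the sampler returns the accepted prefix plus one extra token (a bonus token on full acceptance, or the corrected token at the first mismatch, per the rejection-sampler tests above). A minimal standalone sketch of that arithmetic follows; the helper name is illustrative only and not defined anywhere in the diff.

from typing import List

def advance_num_computed_tokens(num_computed_tokens: int,
                                num_scheduled_tokens: int,
                                num_spec_tokens: int,
                                sampled_token_ids: List[int]) -> int:
    # Verification rejects (num_spec_tokens + 1) - len(sampled_token_ids)
    # of the scheduled positions, mirroring
    # num_scheduled_tokens[req_id] - (len(spec_token_ids) + 1 - len(token_ids))
    # in update_from_output above.
    num_rejected = num_spec_tokens + 1 - len(sampled_token_ids)
    return num_computed_tokens + num_scheduled_tokens - num_rejected

# Example: 3 speculative tokens scheduled on top of the last output token
# (num_scheduled_tokens = 4); 2 of them are accepted and 1 bonus token is
# sampled, so 3 token ids come back and exactly 1 scheduled position is
# rejected. Only the accepted positions count as computed, which also keeps
# their KV blocks out of the prefix cache until they are accepted.
assert advance_num_computed_tokens(10, 4, 3, [7, 8, 9]) == 13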
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 556959532f42..4ca5d267a556 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -25,6 +25,7 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -67,6 +68,12 @@ def __init__( self.mm_input_mapper_server = MMInputMapperServer( vllm_config.model_config) + # TODO: find a better way to check if we are using ngram. + if self.scheduler.speculative_config: + assert self.scheduler.speculative_config.ngram_prompt_lookup_min \ + , "Only ngram spec decode is supported in V1." + self.proposer = NgramProposer() + def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: start = time.time() @@ -120,16 +127,22 @@ def abort_requests(self, request_ids: List[str]): def propose_tokens(self): assert self.scheduler.speculative_config is not None - # Append tokens to requests directly - # to mimic ngram proposal. - # Only change requests in the decoding phase. - # We don't handle spec decode kv cache for now. for req in self.scheduler.running: # Ignore requests that are doing chunked prefill. - if req.num_computed_tokens >= req.num_tokens: - req.append_spec_token_ids( - [1] * - self.scheduler.speculative_config.num_speculative_tokens) + if req.num_computed_tokens < req.num_tokens - 1: + print("**", req.num_computed_tokens, req.num_tokens) + continue + # Ignore requests that already have spec tokens. + if len(req.spec_token_ids) > 0: + continue + spec_tokens = self.proposer.propose( + req.all_token_ids, + self.scheduler.speculative_config.ngram_prompt_lookup_min, + self.scheduler.speculative_config.num_speculative_tokens, + ) + print(f"Proposed tokens: {spec_tokens}") + if spec_tokens: + req.append_spec_token_ids(spec_tokens) def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py new file mode 100644 index 000000000000..10c5999f0a5c --- /dev/null +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -0,0 +1,82 @@ +from typing import List, Optional +from vllm.v1.utils import ConstantList + +class NgramProposer(): + def __init__(self): + pass + + def _kmp_lps_array(self, pattern: List[int]) -> List[int]: + """ + Build the lps (longest proper prefix which is also suffix) array for the pattern. + """ + lps = [0] * len(pattern) + prev_lps = 0 # length of the previous longest prefix suffix + i = 1 + + while i < len(pattern): + if pattern[i] == pattern[prev_lps]: + prev_lps += 1 + lps[i] = prev_lps + i += 1 + else: + if prev_lps != 0: + prev_lps = lps[prev_lps - 1] + else: + lps[i] = 0 + i += 1 + + return lps + + def _find_subarray_kmp(self, + X: List[int], + Y: List[int], + K: int) -> Optional[List[int]]: + """ + Returns the subarray starting at the first occurrence of Y in X, + plus K subsequent elements (if available). If not found, returns None. 
+ """ + N = len(X) + M = len(Y) + + if M == 0: + # If Y is empty, + # let's define that it matches at index 0 + return X[:K] + + # Precompute lps array for Y + lps = self._kmp_lps_array(Y) + + i = 0 # index for X + j = 0 # index for Y + + while i < N: + if X[i] == Y[j]: + i += 1 + j += 1 + + # If we have matched the entire Y + if j == M: + # Found Y in X, gather the next K elements + start_index = i - M # Where the match started + return X[start_index : start_index + M + K] + else: + # Mismatch + if j != 0: + # Use the lps array to avoid re-checking elements + j = lps[j - 1] + else: + i += 1 + + # Y not found + return None + + def propose(self, + context_token_ids: ConstantList[int], + n: int, k: int) -> Optional[List[int]]: + ngrams = context_token_ids[-n:] + lookup_tokens = context_token_ids[:-n] + match_tokens = self._find_subarray_kmp(lookup_tokens, + ngrams, k) + if match_tokens is None: + return None + return match_tokens[n:] \ No newline at end of file From 5e7306e303ef4237fb77d6f5f6fbdb16f2e9ddcd Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 23 Jan 2025 19:29:44 -0800 Subject: [PATCH 18/75] style and minor output token fix --- tests/v1/sample/test_rejection_sampler.py | 116 +++++++++++++--------- tests/v1/spec_decode/test_ngram.py | 52 +++++----- vllm/v1/core/scheduler.py | 6 +- vllm/v1/engine/core.py | 1 - vllm/v1/spec_decode/ngram_proposer.py | 34 +++---- vllm/v1/worker/gpu_model_runner.py | 3 - 6 files changed, 115 insertions(+), 97 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index d882e333624b..9ad299b8e293 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -1,145 +1,163 @@ +from typing import List + import pytest import torch -from typing import List -from vllm.v1.sample.rejection_sampler import RejectionSampler + from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler + @pytest.fixture def sampler(): return RejectionSampler() -def create_logits_tensor(token_ids: List[int], vocab_size: int = 100) -> torch.Tensor: - """Helper function to create logits tensor that will produce desired token ids on argmax""" + +def create_logits_tensor(token_ids: List[int], + vocab_size: int = 100) -> torch.Tensor: + """Helper function to create logits tensor that + will produce desired token ids on argmax""" logits = torch.full((len(token_ids), vocab_size), -100.0) for i, token_id in enumerate(token_ids): logits[i, token_id] = 100.0 return logits + def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: - return SamplingMetadata( - temperature=0.0, - all_greedy=True, - all_random=False, - rejection_sampling=True, - spec_token_ids=spec_tokens, - top_p=None, - top_k=None, - no_top_p=False, - no_top_k=False, - generators={}, - max_num_logprobs=0, - no_penalties=False, - prompt_token_ids=None, - frequency_penalties=torch.tensor([]), - presence_penalties=torch.tensor([]), - repetition_penalties=torch.tensor([]), - output_token_ids=[], - min_tokens=[], - stop_token_ids=[] - ) + return SamplingMetadata(temperature=0.0, + all_greedy=True, + all_random=False, + rejection_sampling=True, + spec_token_ids=spec_tokens, + top_p=None, + top_k=None, + no_top_p=False, + no_top_k=False, + generators={}, + max_num_logprobs=0, + no_penalties=False, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + 
output_token_ids=[], + min_tokens=[], + stop_token_ids=[]) + def test_perfect_match(sampler): """Test when output tokens perfectly match speculated tokens""" spec_tokens = [[1, 2, 3]] output_tokens = [1, 2, 3, 4] # 4 is the bonus token - + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.sampled_token_ids == [[1, 2, 3, 4]] + def test_early_mismatch(sampler): """Test when there's an early mismatch in tokens""" spec_tokens = [[1, 2, 3]] output_tokens = [1, 5, 3, 4] # Mismatch at position 1 - + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.sampled_token_ids == [[1, 5]] + def test_multiple_sequences(sampler): """Test handling multiple sequences of speculated tokens""" spec_tokens = [[1, 2], [3, 4]] - output_tokens = [1, 2, 5, 3, 4, 6] # Two sequences with bonus tokens 5 and 6 - + output_tokens = [1, 2, 5, 3, 4, + 6] # Two sequences with bonus tokens 5 and 6 + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.sampled_token_ids == [[1, 2, 5], [3, 4, 6]] + def test_single_token_sequence(sampler): """Test handling sequences with single token""" spec_tokens = [[1]] output_tokens = [1, 2] # Single token with bonus token 2 - + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.sampled_token_ids == [[1, 2]] + def test_empty_sequence(sampler): """Test handling empty sequence of speculated tokens""" spec_tokens: List[List[int]] = [[]] output_tokens = [5] # Just the bonus token - + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.sampled_token_ids == [[5]] + def test_multiple_mismatches(sampler): """Test handling multiple sequences with mismatches""" spec_tokens = [[1, 2, 3], [4, 5, 6]] output_tokens = [1, 2, 7, 6, 4, 8, 6, 9] # Mismatches in both sequences - + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.sampled_token_ids == [[1, 2, 7], [4, 8]] -@pytest.mark.parametrize("spec_tokens,output_tokens,expected", [ - ([[1, 2]], [1, 2, 3], [[1, 2, 3]]), # Perfect match with bonus - ([[1]], [2, 3], [[2]]), # First mismatch - ([[1, 2], [3, 4]], [1, 5, 6, 3, 4, 7], [[1, 5], [3, 4, 7]]), # Mixed matches -]) + +@pytest.mark.parametrize( + "spec_tokens,output_tokens,expected", + [ + ([[1, 2]], [1, 2, 3], [[1, 2, 3]]), # Perfect match with bonus + ([[1]], [2, 3], [[2]]), # First mismatch + ([[1, 2], [3, 4]], [1, 5, 6, 3, 4, 7], [[1, 5], [3, 4, 7] + ]), # Mixed matches + ]) def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): """Parametrized test for various matching scenarios""" metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.sampled_token_ids == expected + def test_logits_shape_handling(sampler): """Test handling of different logits tensor shapes""" spec_tokens = [[1, 2]] output_tokens = [1, 2, 3] vocab_size = 1000 - + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens, vocab_size) - + output = sampler.sample(logits, metadata) 
assert output.sampled_token_ids == [[1, 2, 3]] assert logits.shape[-1] == vocab_size + def test_none_outputs(sampler): """Test that other output fields are None as expected""" spec_tokens = [[1]] output_tokens = [1, 2] - + metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - + output = sampler.sample(logits, metadata) assert output.logprob_token_ids is None assert output.logprobs is None assert output.prompt_logprob_token_ids is None - assert output.prompt_logprobs is None \ No newline at end of file + assert output.prompt_logprobs is None diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 654b7ecc21d7..f1c209b58155 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -1,34 +1,38 @@ import pytest -from typing import List, Optional -from vllm.v1.utils import ConstantList + from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import ConstantList + @pytest.fixture def proposer(): - return NgramProposer() + return NgramProposer() + def test_kmp_lps_array(proposer): - assert proposer._kmp_lps_array([]) == [] - assert proposer._kmp_lps_array([1]) == [0] - assert proposer._kmp_lps_array([1, 1, 1]) == [0, 1, 2] - assert proposer._kmp_lps_array([1, 2, 3, 4]) == [0, 0, 0, 0] - assert proposer._kmp_lps_array([1, 2, 1, 2, 3]) == [0, 0, 1, 2, 0] + assert proposer._kmp_lps_array([]) == [] + assert proposer._kmp_lps_array([1]) == [0] + assert proposer._kmp_lps_array([1, 1, 1]) == [0, 1, 2] + assert proposer._kmp_lps_array([1, 2, 3, 4]) == [0, 0, 0, 0] + assert proposer._kmp_lps_array([1, 2, 1, 2, 3]) == [0, 0, 1, 2, 0] + def test_find_subarray_kmp(proposer): - X = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) - - assert proposer._find_subarray_kmp(X, [], 2) == [1, 2] - assert proposer._find_subarray_kmp(X, [7, 8], 1) is None - assert proposer._find_subarray_kmp(X, [1, 2, 3], 2) == [1, 2, 3, 4, 1] - assert proposer._find_subarray_kmp(X, [5, 6], 1) == [5, 6] - + X = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) + + assert proposer._find_subarray_kmp(X, [], 2) == [1, 2] + assert proposer._find_subarray_kmp(X, [7, 8], 1) is None + assert proposer._find_subarray_kmp(X, [1, 2, 3], 2) == [1, 2, 3, 4, 1] + assert proposer._find_subarray_kmp(X, [5, 6], 1) == [5, 6] + + def test_propose(proposer): - context = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) - assert proposer.propose(context, n=3, k=2) is None - assert proposer.propose(context, n=2, k=1) is None - assert proposer.propose(context, n=4, k=2) is None - - context = ConstantList([1, 2, 3, 4, 1, 2]) - assert proposer.propose(context, n=2, k=2) == [3, 4] - assert proposer.propose(context, n=2, k=1) == [3] - assert proposer.propose(context, n=3, k=2) is None \ No newline at end of file + context = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) + assert proposer.propose(context, n=3, k=2) is None + assert proposer.propose(context, n=2, k=1) is None + assert proposer.propose(context, n=4, k=2) is None + + context = ConstantList([1, 2, 3, 4, 1, 2]) + assert proposer.propose(context, n=2, k=2) == [3, 4] + assert proposer.propose(context, n=2, k=1) == [3] + assert proposer.propose(context, n=3, k=2) is None diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 092e79a0faa0..cfea256f09fd 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -447,14 +447,14 @@ def update_from_output( if request.num_computed_tokens >= request.num_tokens: request.clear_spec_tokens() + num_tokens_before_step = 
request.num_tokens request.append_output_token_ids(token_ids) - num_new_tokens = len(token_ids) # TODO: Update the KV cache manager for prefix caching. - # self.kv_cache_manager.uncache_blocks(request) # Check for stop and update request state. # This must be called before me make the EngineCoreOutput. stopped = self._check_stop(request) + num_new_tokens = request.num_tokens - num_tokens_before_step # Add EngineCoreOutput for this Request. output = EngineCoreOutput( @@ -497,7 +497,7 @@ def _check_stop(self, request: Request) -> bool: num_total_token = min( self.max_model_len, request.num_tokens, request.max_tokens + request.num_prompt_tokens, - request.num_output_tokens + request.num_output_tokens) + request.num_output_tokens + request.num_prompt_tokens) self._crop_request(request, num_total_token) self._free_request(request) return True diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4ca5d267a556..eae93f5c43ed 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -140,7 +140,6 @@ def propose_tokens(self): self.scheduler.speculative_config.ngram_prompt_lookup_min, self.scheduler.speculative_config.num_speculative_tokens, ) - print(f"Proposed tokens: {spec_tokens}") if spec_tokens: req.append_spec_token_ids(spec_tokens) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 10c5999f0a5c..95557cbb41c6 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,13 +1,17 @@ from typing import List, Optional + from vllm.v1.utils import ConstantList -class NgramProposer(): + +class NgramProposer: + def __init__(self): pass - + def _kmp_lps_array(self, pattern: List[int]) -> List[int]: """ - Build the lps (longest proper prefix which is also suffix) array for the pattern. + Build the lps (longest proper prefix which is also suffix) + array for the pattern. 
""" lps = [0] * len(pattern) prev_lps = 0 # length of the previous longest prefix suffix @@ -27,9 +31,7 @@ def _kmp_lps_array(self, pattern: List[int]) -> List[int]: return lps - def _find_subarray_kmp(self, - X: List[int], - Y: List[int], + def _find_subarray_kmp(self, X: List[int], Y: List[int], K: int) -> Optional[List[int]]: """ Returns the subarray starting at the first occurrence of Y in X, @@ -39,7 +41,7 @@ def _find_subarray_kmp(self, M = len(Y) if M == 0: - # If Y is empty, + # If Y is empty, # let's define that it matches at index 0 return X[:K] @@ -57,8 +59,8 @@ def _find_subarray_kmp(self, # If we have matched the entire Y if j == M: # Found Y in X, gather the next K elements - start_index = i - M # Where the match started - return X[start_index : start_index + M + K] + start_index = i - M # Where the match started + return X[start_index:start_index + M + K] else: # Mismatch if j != 0: @@ -66,17 +68,15 @@ def _find_subarray_kmp(self, j = lps[j - 1] else: i += 1 - + # Y not found return None - - def propose(self, - context_token_ids: ConstantList[int], - n: int, k: int) -> Optional[List[int]]: + + def propose(self, context_token_ids: ConstantList[int], n: int, + k: int) -> Optional[List[int]]: ngrams = context_token_ids[-n:] lookup_tokens = context_token_ids[:-n] - match_tokens = self._find_subarray_kmp(lookup_tokens, - ngrams, k) + match_tokens = self._find_subarray_kmp(lookup_tokens, ngrams, k) if match_tokens is None: return None - return match_tokens[n:] \ No newline at end of file + return match_tokens[n:] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7c86247698c5..f563000e59f3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -449,7 +449,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) - print("input_ids", self.input_ids[:total_num_scheduled_tokens]) if self.model_config.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( @@ -814,8 +813,6 @@ def execute_model( seq_len = req_state.num_computed_tokens + \ scheduler_output.num_scheduled_tokens[req_id] if seq_len >= req_state.num_tokens: - print(req_state.req_id, "output_token_ids", - sampled_token_ids[i], req_state.num_computed_tokens) # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. 
token_ids = sampled_token_ids[i] From a26df8db580aa5927c58018adbcae6bca947b446 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 27 Jan 2025 21:58:16 -0800 Subject: [PATCH 19/75] partial cleanup & update the kmp --- tests/v1/spec_decode/test_ngram.py | 27 ++++------ vllm/v1/core/scheduler.py | 11 ++-- vllm/v1/engine/core.py | 35 ++++++------ vllm/v1/sample/rejection_sampler.py | 2 +- vllm/v1/spec_decode/ngram_proposer.py | 78 ++++++++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 1 - 6 files changed, 80 insertions(+), 74 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index f1c209b58155..8d365d80f180 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -19,20 +19,13 @@ def test_kmp_lps_array(proposer): def test_find_subarray_kmp(proposer): X = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) - - assert proposer._find_subarray_kmp(X, [], 2) == [1, 2] - assert proposer._find_subarray_kmp(X, [7, 8], 1) is None - assert proposer._find_subarray_kmp(X, [1, 2, 3], 2) == [1, 2, 3, 4, 1] - assert proposer._find_subarray_kmp(X, [5, 6], 1) == [5, 6] - - -def test_propose(proposer): - context = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6]) - assert proposer.propose(context, n=3, k=2) is None - assert proposer.propose(context, n=2, k=1) is None - assert proposer.propose(context, n=4, k=2) is None - - context = ConstantList([1, 2, 3, 4, 1, 2]) - assert proposer.propose(context, n=2, k=2) == [3, 4] - assert proposer.propose(context, n=2, k=1) == [3] - assert proposer.propose(context, n=3, k=2) is None + assert proposer._find_subarray_kmp(X, 2, 2) is None + X = ConstantList([1, 2, 3, 4, 1, 2, 3]) + assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2] + assert proposer._find_subarray_kmp(X, 2, 2) == [4, 1] + assert proposer._find_subarray_kmp(X, 1, 3) == [4, 1, 2] + assert proposer._find_subarray_kmp(X, 1, 2) == [4, 1] + X = ConstantList([1, 3, 6, 2, 3, 4, 1, 2, 3]) + assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2] + # Return on the first match + assert proposer._find_subarray_kmp(X, 1, 3) == [6, 2, 3] \ No newline at end of file diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index cfea256f09fd..4d1f773be19a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -114,10 +114,11 @@ def schedule(self) -> "SchedulerOutput": token_budget = self.max_num_scheduled_tokens # Encoder-related. scheduled_encoder_inputs: Dict[str, List[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens + # Spec Decode-related. spec_decode = False scheduled_spec_decode_tokens: Dict[str, List[int]] = {} - encoder_budget = self.max_num_encoder_input_tokens # First, schedule the RUNNING requests. # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be @@ -190,6 +191,7 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Speculative decode related. if request.spec_token_ids: spec_decode = True scheduled_spec_decode_tokens[ @@ -486,10 +488,9 @@ def _crop_request(self, request: Request, num_total_token: int) -> None: request.crop(num_total_token) def _check_stop(self, request: Request) -> bool: - """ - Check if the request should be stopped. - The function should handle both single token generation or - multiple token generation (e.g., spec decode) per step. + """Check if the request should be stopped. 
+ The function should handle both single token generation or + multiple token generation (e.g., spec decode) per step. """ if (request.num_tokens >= self.max_model_len or request.num_output_tokens >= request.max_tokens): diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index eae93f5c43ed..924fc2388b41 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -125,24 +125,6 @@ def abort_requests(self, request_ids: List[str]): self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) - def propose_tokens(self): - assert self.scheduler.speculative_config is not None - for req in self.scheduler.running: - # Ignore requests that are doing chunked prefill. - if req.num_computed_tokens < req.num_tokens - 1: - print("**", req.num_computed_tokens, req.num_tokens) - continue - # Ignore requests that already have spec tokens. - if len(req.spec_token_ids) > 0: - continue - spec_tokens = self.proposer.propose( - req.all_token_ids, - self.scheduler.speculative_config.ngram_prompt_lookup_min, - self.scheduler.speculative_config.num_speculative_tokens, - ) - if spec_tokens: - req.append_spec_token_ids(spec_tokens) - def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" @@ -165,6 +147,23 @@ def shutdown(self): def profile(self, is_start: bool = True): self.model_executor.profile(is_start) + def propose_tokens(self): + assert self.scheduler.speculative_config is not None + for req in self.scheduler.running: + # Ignore requests that are doing chunked prefill. + if req.num_computed_tokens < req.num_tokens - 1: + continue + # Ignore requests that already have spec tokens. + if len(req.spec_token_ids) > 0: + continue + spec_tokens = self.proposer.propose( + req.all_token_ids, + self.scheduler.speculative_config.ngram_prompt_lookup_min, + self.scheduler.speculative_config.num_speculative_tokens, + ) + if spec_tokens: + req.append_spec_token_ids(spec_tokens) + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 854317077641..32315d8a5f53 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -12,7 +12,7 @@ class RejectionSampler(nn.Module): def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: - # num_reqs x [num_specuated_tokens] + # num_reqs x [num_speculated_tokens] spec_token_ids = sampling_metadata.spec_token_ids # only argmax is supported for now output_token_ids_cpu = logits.argmax(dim=-1).view(-1).tolist() diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 95557cbb41c6..ede3f1954d27 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -8,6 +8,38 @@ class NgramProposer: def __init__(self): pass + def propose(self, context_token_ids: ConstantList[int], n: int, + k: int) -> Optional[List[int]]: + """Proposes the next sequence of tokens based on n-gram pattern + matching in the context. The function finds matches of the last n + tokens in the previous context, and returns k tokens that followed + that match. + + Args: + context_token_ids: List of token IDs representing the + context sequence. + n: Length of the n-gram to match. + k: Number of tokens follow the match. If there are less + than k tokens follow the match, we will return + the maximum amount of tokens until the end. 
+ + Returns: + List[int]: The sequence of tokens that followed + the matched n-gram in the context. + None: If no matching n-gram pattern is found. + + Example: + If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4: + - The last 2 tokens [2,3] will be matched against the previous + 4 tokens [1,2,3,4]. + - Finding a match of [2,3] would return the tokens that + followed that pattern. Here we will return [4,2,3] because + we only have three tokens after the match. + """ + # TODO: Use c++ to implement the _find_subarray_kmp to + # improve the efficiency + return self._find_subarray_kmp(context_token_ids, n, k) + def _kmp_lps_array(self, pattern: List[int]) -> List[int]: """ Build the lps (longest proper prefix which is also suffix) @@ -31,36 +63,27 @@ def _kmp_lps_array(self, pattern: List[int]) -> List[int]: return lps - def _find_subarray_kmp(self, X: List[int], Y: List[int], - K: int) -> Optional[List[int]]: - """ - Returns the subarray starting at the first occurrence of Y in X, - plus K subsequent elements (if available). If not found, returns None. - """ - N = len(X) - M = len(Y) - - if M == 0: - # If Y is empty, - # let's define that it matches at index 0 - return X[:K] + def _find_subarray_kmp(self, context_token_ids: ConstantList[int], n: int, + k: int) -> Optional[List[int]]: + context_len = len(context_token_ids) + assert n > 0 + pattern = context_token_ids[-n:] # Precompute lps array for Y - lps = self._kmp_lps_array(Y) - - i = 0 # index for X - j = 0 # index for Y + lps = self._kmp_lps_array(pattern) - while i < N: - if X[i] == Y[j]: + i = 0 + j = 0 + # -n because the last n tokens are used as pattern + while i < context_len - n: + if context_token_ids[i] == pattern[j]: i += 1 j += 1 # If we have matched the entire Y - if j == M: - # Found Y in X, gather the next K elements - start_index = i - M # Where the match started - return X[start_index:start_index + M + K] + if j == n: + # Found pattern in context, gather the next K elements + return context_token_ids[i:i + k] else: # Mismatch if j != 0: @@ -71,12 +94,3 @@ def _find_subarray_kmp(self, X: List[int], Y: List[int], # Y not found return None - - def propose(self, context_token_ids: ConstantList[int], n: int, - k: int) -> Optional[List[int]]: - ngrams = context_token_ids[-n:] - lookup_tokens = context_token_ids[:-n] - match_tokens = self._find_subarray_kmp(lookup_tokens, ngrams, k) - if match_tokens is None: - return None - return match_tokens[n:] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f563000e59f3..d6300684c343 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -791,7 +791,6 @@ def execute_model( hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(hidden_states, None) - logger.info("logits: %s", str(logits.shape)) # Sample the next token and get logprobs if needed. 
sampling_metadata = self._prepare_sampling(scheduler_output) From eeab20476fc5f67ded171571d7b7b2499b361873 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 28 Jan 2025 10:58:17 -0800 Subject: [PATCH 20/75] minor --- tests/v1/core/test_stop_checking.py | 13 +++++++++++-- vllm/v1/worker/gpu_model_runner.py | 4 +--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/v1/core/test_stop_checking.py b/tests/v1/core/test_stop_checking.py index 4913d9b9ab6c..b2d9c21efe7b 100644 --- a/tests/v1/core/test_stop_checking.py +++ b/tests/v1/core/test_stop_checking.py @@ -59,10 +59,20 @@ def test_multiple_stop_tokens(scheduler): assert result is True assert request.status == RequestStatus.FINISHED_STOPPED assert request.request_id in scheduler.finished_req_ids - # Should be cropped at the first stop token assert len(request.output_token_ids) == 1 assert list(request.output_token_ids) == [EOS_TOKEN_ID] + # EOS token is generated in the middle of the output tokens + request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([1, 2, 3, 4, 5, EOS_TOKEN_ID, 7, 43, 5]) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_STOPPED + assert request.request_id in scheduler.finished_req_ids + assert len(request.output_token_ids) == 6 + assert list(request.output_token_ids) == [1, 2, 3, 4, 5, EOS_TOKEN_ID] + # Stop token, 43 is one of the stop tokens request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request @@ -88,6 +98,5 @@ def test_multiple_stop_tokens(scheduler): assert result is True assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED assert request.request_id in scheduler.finished_req_ids - # Should be cropped at the first stop token assert len(request.output_token_ids) == max_tokens assert list(request.output_token_ids) == output_token_ids[:max_tokens] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d6300684c343..36dbac8031ae 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -345,9 +345,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # TODO: The Python loop can be slow. Optimize. 
num_scheduled_tokens = [] max_num_scheduled_tokens = 0 - for i, req_id in enumerate(self.input_batch.req_ids): - if i == num_reqs: - break + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) From 6772e079f0f0fc1df31cea64d53c458e3707ab55 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 28 Jan 2025 10:58:32 -0800 Subject: [PATCH 21/75] minor --- tests/v1/e2e/test_basic_specdecode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_basic_specdecode.py b/tests/v1/e2e/test_basic_specdecode.py index 1aecc6928e08..cdcdac574416 100644 --- a/tests/v1/e2e/test_basic_specdecode.py +++ b/tests/v1/e2e/test_basic_specdecode.py @@ -1,7 +1,7 @@ from vllm import LLM, SamplingParams prompts = [ - "The future of AI is", + "Can you repeat the sentence ten times, this is a sentence?", "This is a basic spec decode test", ] # Only support greedy for now From a5932a701449ddb5ac22d9a4d217a03b1008610b Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 31 Jan 2025 00:00:21 -0800 Subject: [PATCH 22/75] fix comments --- vllm/v1/core/scheduler.py | 39 ++++++++++++++----------------- vllm/v1/engine/core.py | 2 -- vllm/v1/request.py | 11 +++++++++ vllm/v1/worker/gpu_input_batch.py | 3 +++ 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 4d1f773be19a..5c05973ec789 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -478,19 +478,16 @@ def update_from_output( scheduler_stats=self.make_stats(), ) - def _crop_request(self, request: Request, num_total_token: int) -> None: - """Truncate the request to the num_total_token. - We do not need to update the input batch because - it will be updated in the next execute_model call's - _update_states method, where the request data is aligned - with the data in the persistent batch. - """ - request.crop(num_total_token) - def _check_stop(self, request: Request) -> bool: """Check if the request should be stopped. The function should handle both single token generation or multiple token generation (e.g., spec decode) per step. + + This function will crop requests to the given number of tokens. + When cropping, we do not need to update the input batch because + it will be updated in the next execute_model call's + _update_states method, where the request data is aligned + with the data in the persistent batch. 
""" if (request.num_tokens >= self.max_model_len or request.num_output_tokens >= request.max_tokens): @@ -499,30 +496,30 @@ def _check_stop(self, request: Request) -> bool: self.max_model_len, request.num_tokens, request.max_tokens + request.num_prompt_tokens, request.num_output_tokens + request.num_prompt_tokens) - self._crop_request(request, num_total_token) + request.crop(num_total_token) self._free_request(request) return True sampling_params = request.sampling_params - if (not sampling_params.ignore_eos - and request.eos_token_id in request.output_token_ids): + if not sampling_params.ignore_eos: assert request.eos_token_id is not None - request.status = RequestStatus.FINISHED_STOPPED - num_total_token = request.num_prompt_tokens + \ - request.output_token_ids.index(request.eos_token_id) + 1 - self._crop_request(request, num_total_token) - self._free_request(request) - return True + if request.eos_token_id in request.output_token_ids: + assert request.eos_token_id is not None + request.status = RequestStatus.FINISHED_STOPPED + num_total_token = request.num_prompt_tokens + \ + request.output_token_ids.index(request.eos_token_id) + 1 + request.crop(num_total_token) + self._free_request(request) + return True - stop_token_ids = set(sampling_params.stop_token_ids or set([])) output_token_ids = set(request.output_token_ids) - for stop_token_id in stop_token_ids: + for stop_token_id in sampling_params.stop_token_ids: if stop_token_id in output_token_ids: request.status = RequestStatus.FINISHED_STOPPED request.stop_reason = stop_token_id num_total_token = request.num_prompt_tokens + \ request.output_token_ids.index(stop_token_id) + 1 - self._crop_request(request, num_total_token) + request.crop(num_total_token) self._free_request(request) return True return False diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 924fc2388b41..976cbd8e15da 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -63,8 +63,6 @@ def __init__( lora_config=vllm_config.lora_config, speculative_config=vllm_config.speculative_config) - self._last_logging_time = time.time() - self.mm_input_mapper_server = MMInputMapperServer( vllm_config.model_config) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 1bea0b1cefa4..02f45c4a47c0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -158,6 +158,17 @@ def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: self._kv_block_hashes.append(block_hash) def crop(self, num_total_token: int) -> None: + """Crops the token sequences to a specified total length while + preserving prompt tokens. + + Args: + num_total_token: The desired total number of tokens after cropping. + + Raises: + ValueError: If num_total_token is less than the number of prompt + tokens, as prompt tokens cannot be cropped. 
+ """ + if num_total_token < self.num_prompt_tokens: raise ValueError("Cannot crop the prompt tokens.") num_output_token = num_total_token - self.num_prompt_tokens diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index b86dd48ae3ac..058dbd2f7485 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -366,6 +366,9 @@ def make_sampling_metadata( if rejection_sampling: assert req_id_to_spec_token_ids is not None spec_token_ids.append(req_id_to_spec_token_ids[req_id]) + else: + assert req_id_to_spec_token_ids is None, \ + "spec_token_ids can only be set with rejection sampling" return SamplingMetadata( temperature=self.temperature[:self.num_reqs], From 5d3a31ae010e205e4db049f1c9c0cf281f333d75 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 11:50:44 -0800 Subject: [PATCH 23/75] change sampled_token_ids to tensor --- vllm/v1/core/scheduler.py | 2 -- vllm/v1/outputs.py | 5 ++--- vllm/v1/sample/sampler.py | 2 +- vllm/v1/worker/gpu_input_batch.py | 5 ----- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 5 files changed, 6 insertions(+), 13 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 052f953fd645..72b912ec9147 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -515,12 +515,10 @@ def _check_stop(self, request: Request) -> bool: sampling_params = request.sampling_params if not sampling_params.ignore_eos: assert request.eos_token_id is not None - # print(request.eos_token_id in request.output_token_ids) if request.eos_token_id in request.output_token_ids: request.status = RequestStatus.FINISHED_STOPPED num_total_token = request.num_prompt_tokens + \ request.output_token_ids.index(request.eos_token_id) + 1 - print("**", num_total_token) request.crop(num_total_token) self._free_request(request) return True diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index de2eeb6fff7f..48a50da88d6d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -32,9 +32,8 @@ class ModelRunnerOutput: # req_id -> index req_id_to_index: Dict[str, int] - # num_reqs x [num_generated_tokens] - # num_generated_tokens might be different for each request. 
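
As an aside on the crop() contract documented above: only output tokens may ever be dropped, never prompt tokens. A tiny standalone sketch of the same rule (an illustrative helper, not the Request method itself):

def crop_outputs(prompt_ids, output_ids, num_total_token):
    # Keep the prompt intact and trim the outputs so that
    # len(prompt_ids) + len(kept) == num_total_token.
    if num_total_token < len(prompt_ids):
        raise ValueError("Cannot crop the prompt tokens.")
    return output_ids[:num_total_token - len(prompt_ids)]

print(crop_outputs([1, 2, 3], [10, 11, 12, 13], num_total_token=5))  # [10, 11]
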
- sampled_token_ids: List[List[int]] + # num_reqs x [max_num_generated_tokens] + sampled_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] logprob_token_ids_cpu: Optional[torch.Tensor] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 61f97ff7deee..2543864e0550 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -60,7 +60,7 @@ def forward( sampled = sampled.to(torch.int32) sampler_output = SamplerOutput( - sampled_token_ids=[[x] for x in sampled.tolist()], + sampled_token_ids=sampled.unsqueeze(-1), logprob_token_ids=topk_indices, logprobs=topk_logprobs, prompt_logprob_token_ids=None, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a77b5f16be71..0908227e0463 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -369,11 +369,6 @@ def make_sampling_metadata( assert req_id_to_spec_token_ids is not None assert len(req_id_to_spec_token_ids) > 0 spec_token_ids.append(req_id_to_spec_token_ids[req_id]) - else: - assert req_id_to_spec_token_ids is not None - assert len(req_id_to_spec_token_ids) == 0, \ - ("spec_token_ids can only be set with " - f"rejection sampling: {req_id_to_spec_token_ids}") return SamplingMetadata( temperature=self.temperature[:self.num_reqs], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e87bf821fb1e..34bda7ad2709 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -834,7 +834,7 @@ def execute_model( if seq_len >= req_state.num_tokens: # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. - token_len = sampled_token_ids[i].shape[-1] + token_len = sampled_token_ids.shape[-1] self.input_batch.num_tokens[i] += token_len req_state.output_token_ids.extend([0] * token_len) request_seq_lens.append((i, req_state, seq_len)) @@ -855,8 +855,9 @@ def execute_model( # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. # Update with the actual token ids + sampled_token_ids = sampled_token_ids.tolist() for i, req_state, seq_len in request_seq_lens: - token_ids = sampler_output.sampled_token_ids[i].tolist() + token_ids = sampled_token_ids[i] spec_token_ids = spec_tokens.get(req_id or "", []) for j, token_id in enumerate(token_ids): self.input_batch.token_ids_cpu[i, From f7f4c24ae9e7fff378c6c1d711de18c600acc2bd Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 12:35:19 -0800 Subject: [PATCH 24/75] minor --- vllm/v1/worker/gpu_model_runner.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 34bda7ad2709..3965377724f6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,6 +31,7 @@ KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -819,7 +820,6 @@ def execute_model( ) sampled_token_ids = sampler_output.sampled_token_ids - spec_tokens = scheduler_output.scheduled_spec_decode_tokens # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. 
num_reqs = self.input_batch.num_reqs @@ -834,10 +834,11 @@ def execute_model( if seq_len >= req_state.num_tokens: # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. - token_len = sampled_token_ids.shape[-1] - self.input_batch.num_tokens[i] += token_len - req_state.output_token_ids.extend([0] * token_len) - request_seq_lens.append((i, req_state, seq_len)) + gen_len = (sampled_token_ids[i] + != INVALID_TOKEN_ID).sum().item() + self.input_batch.num_tokens[i] += gen_len + req_state.output_token_ids.extend([0] * gen_len) + request_seq_lens.append((i, req_state, gen_len)) else: # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. @@ -856,16 +857,12 @@ def execute_model( # Move as many CPU operations as possible before this sync point. # Update with the actual token ids sampled_token_ids = sampled_token_ids.tolist() - for i, req_state, seq_len in request_seq_lens: + for i, req_state, gen_len in request_seq_lens: token_ids = sampled_token_ids[i] - spec_token_ids = spec_tokens.get(req_id or "", []) for j, token_id in enumerate(token_ids): - self.input_batch.token_ids_cpu[i, - seq_len - len(spec_token_ids) + - j] = token_id + self.input_batch.token_ids_cpu[i, -gen_len + j] = token_id - req_state.output_token_ids[-1 - len(spec_token_ids) + - j] = token_id + req_state.output_token_ids[-gen_len + j] = token_id if sampler_output.logprob_token_ids is None: logprob_token_ids = None From a1eecd3d2f05f1d4afd02917595021e48c88ede1 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 12:40:20 -0800 Subject: [PATCH 25/75] remove double free --- vllm/v1/core/scheduler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 72b912ec9147..baf8aa1e1ff7 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -509,7 +509,6 @@ def _check_stop(self, request: Request) -> bool: request.max_tokens + request.num_prompt_tokens, request.num_output_tokens + request.num_prompt_tokens) request.crop(num_total_token) - self._free_request(request) return True sampling_params = request.sampling_params @@ -520,7 +519,6 @@ def _check_stop(self, request: Request) -> bool: num_total_token = request.num_prompt_tokens + \ request.output_token_ids.index(request.eos_token_id) + 1 request.crop(num_total_token) - self._free_request(request) return True output_token_ids = set(request.output_token_ids) @@ -531,7 +529,6 @@ def _check_stop(self, request: Request) -> bool: num_total_token = request.num_prompt_tokens + \ request.output_token_ids.index(stop_token_id) + 1 request.crop(num_total_token) - self._free_request(request) return True return False From c843121b5b99242d9e0e83767d22d6ac674a23b1 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 13:57:29 -0800 Subject: [PATCH 26/75] fix bug in input batch token id update --- vllm/v1/worker/gpu_model_runner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3965377724f6..7300eb09386b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -823,7 +823,7 @@ def execute_model( # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. 
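
For intuition on the INVALID_TOKEN_ID handling above: with speculative decoding each row of sampled_token_ids carries a variable number of accepted tokens padded out to the batch-wide maximum, and the padding sentinel is how the runner recovers the per-request count. A small illustration (the sentinel value -1 is assumed here purely for the example; the real constant comes from vllm/v1/sample/rejection_sampler.py):

import torch

INVALID_TOKEN_ID = -1  # placeholder value for this example
sampled_token_ids = torch.tensor([
    [101, 102, 103, INVALID_TOKEN_ID],  # request 0: 2 drafts accepted + bonus
    [201, INVALID_TOKEN_ID, INVALID_TOKEN_ID, INVALID_TOKEN_ID],  # request 1
])
for i in range(sampled_token_ids.shape[0]):
    row = sampled_token_ids[i]
    gen_len = int((row != INVALID_TOKEN_ID).sum().item())
    accepted = [t for t in row.tolist() if t != INVALID_TOKEN_ID]
    print(i, gen_len, accepted)
# 0 3 [101, 102, 103]
# 1 1 [201]
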
num_reqs = self.input_batch.num_reqs - request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] + request_seq_lens: List[Tuple[int, CachedRequestState, int, int]] = [] for i, req_id in enumerate(self.input_batch.req_ids): if i == num_reqs: break @@ -838,7 +838,7 @@ def execute_model( != INVALID_TOKEN_ID).sum().item() self.input_batch.num_tokens[i] += gen_len req_state.output_token_ids.extend([0] * gen_len) - request_seq_lens.append((i, req_state, gen_len)) + request_seq_lens.append((i, req_state, seq_len, gen_len)) else: # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. @@ -857,11 +857,11 @@ def execute_model( # Move as many CPU operations as possible before this sync point. # Update with the actual token ids sampled_token_ids = sampled_token_ids.tolist() - for i, req_state, gen_len in request_seq_lens: + for i, req_state, seq_len, gen_len in request_seq_lens: token_ids = sampled_token_ids[i] for j, token_id in enumerate(token_ids): - self.input_batch.token_ids_cpu[i, -gen_len + j] = token_id - + self.input_batch.token_ids_cpu[i, seq_len - gen_len + j + + 1] = token_id req_state.output_token_ids[-gen_len + j] = token_id if sampler_output.logprob_token_ids is None: From cdcace5e19c440b884490a06331f170d84ad9960 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 14:08:41 -0800 Subject: [PATCH 27/75] constant list for spec tokens --- vllm/v1/core/scheduler.py | 5 +++-- vllm/v1/request.py | 4 ++-- vllm/v1/sample/metadata.py | 4 +++- vllm/v1/worker/gpu_input_batch.py | 7 ++++--- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index baf8aa1e1ff7..f8634cda6ac6 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -16,6 +16,7 @@ from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +from vllm.v1.utils import ConstantList if TYPE_CHECKING: from vllm.multimodal import MultiModalKwargs @@ -120,7 +121,7 @@ def schedule(self) -> "SchedulerOutput": # Spec Decode-related. spec_decode = False - scheduled_spec_decode_tokens: Dict[str, List[int]] = {} + scheduled_spec_decode_tokens: Dict[str, ConstantList[int]] = {} # First, schedule the RUNNING requests. 
# NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be @@ -675,7 +676,7 @@ class SchedulerOutput: total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] use_spec_decode: bool - scheduled_spec_decode_tokens: Dict[str, List[int]] + scheduled_spec_decode_tokens: Dict[str, ConstantList[int]] num_common_prefix_blocks: int preempted_req_ids: Set[str] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index e789c16f6448..3713fbe86ad0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -111,8 +111,8 @@ def clear_spec_tokens(self) -> None: self._spec_token_ids = [] @property - def spec_token_ids(self) -> List[int]: - return self._spec_token_ids + def spec_token_ids(self) -> ConstantList[int]: + return ConstantList(self._spec_token_ids) @property def num_tokens(self) -> int: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 1d0b5e134927..eb9f95ebb611 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -5,6 +5,8 @@ import torch +from vllm.v1.utils import ConstantList + @dataclass class SamplingMetadata: @@ -13,7 +15,7 @@ class SamplingMetadata: all_greedy: bool all_random: bool rejection_sampling: bool - spec_token_ids: List[List[int]] + spec_token_ids: List[ConstantList[int]] top_p: torch.Tensor top_k: torch.Tensor diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 0908227e0463..cafce28fb174 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,6 +11,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import ConstantList from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: @@ -326,7 +327,8 @@ def make_sampling_metadata( req_id_output_token_ids: Dict[str, List[int]], skip_copy: bool = False, rejection_sampling: bool = False, - req_id_to_spec_token_ids: Optional[Dict[str, List[int]]] = None, + req_id_to_spec_token_ids: Optional[Dict[str, + ConstantList[int]]] = None, ) -> SamplingMetadata: if not skip_copy: self.temperature[:self.num_reqs].copy_( @@ -354,7 +356,7 @@ def make_sampling_metadata( self.prompt_token_ids = self._make_prompt_token_ids_tensor() output_token_ids: List[List[int]] = [] - spec_token_ids: List[List[int]] = [] + spec_token_ids: List[ConstantList[int]] = [] for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None # Currently we create a tensor for output_token_ids from scratch @@ -367,7 +369,6 @@ def make_sampling_metadata( output_token_ids.append(req_id_output_token_ids[req_id]) if rejection_sampling: assert req_id_to_spec_token_ids is not None - assert len(req_id_to_spec_token_ids) > 0 spec_token_ids.append(req_id_to_spec_token_ids[req_id]) return SamplingMetadata( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7300eb09386b..613e1a862b16 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -32,7 +32,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID -from vllm.v1.utils import bind_kv_cache +from vllm.v1.utils import ConstantList, bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: @@ -403,8 +403,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ req_id] num_compute_tokens = 
self.input_batch.num_computed_tokens_cpu[i] spec_query_end_loc += req_num_scheduled_tokens - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, []) + spec_token_ids: ConstantList = scheduler_output.\ + scheduled_spec_decode_tokens.get(req_id, ConstantList([])) for j, spec_token_id in enumerate(spec_token_ids): # +1 here because the input for verification is # [last_output_token_id] + spec_token_ids From 2cab6e6745c14eac7d19f8b8d094e6aa2986359c Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 14:23:01 -0800 Subject: [PATCH 28/75] header --- tests/v1/core/test_stop_checking.py | 1 + tests/v1/e2e/test_basic_specdecode.py | 1 + tests/v1/sample/test_rejection_sampler.py | 1 + tests/v1/spec_decode/test_ngram.py | 1 + vllm/v1/sample/rejection_sampler.py | 1 + vllm/v1/spec_decode/ngram_proposer.py | 1 + 6 files changed, 6 insertions(+) diff --git a/tests/v1/core/test_stop_checking.py b/tests/v1/core/test_stop_checking.py index 088f3fe11ec0..192ca276edf2 100644 --- a/tests/v1/core/test_stop_checking.py +++ b/tests/v1/core/test_stop_checking.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 from typing import List import pytest diff --git a/tests/v1/e2e/test_basic_specdecode.py b/tests/v1/e2e/test_basic_specdecode.py index cdcdac574416..2e3f19afb99d 100644 --- a/tests/v1/e2e/test_basic_specdecode.py +++ b/tests/v1/e2e/test_basic_specdecode.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 from vllm import LLM, SamplingParams prompts = [ diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 7f8d4395c6f3..4107b4c0dbe3 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 from typing import List import pytest diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 8d365d80f180..ec663c84d0d2 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 import pytest from vllm.v1.spec_decode.ngram_proposer import NgramProposer diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 92413111506f..4412e18f8fe6 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 import torch import torch.nn as nn diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index ede3f1954d27..4a7e34cf865f 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 from typing import List, Optional from vllm.v1.utils import ConstantList From e30bc0cd1d14d700669975c9f298d2e68fa0c123 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 16:41:19 -0800 Subject: [PATCH 29/75] bug fix for invalid token id check --- vllm/v1/worker/gpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 613e1a862b16..f9f9801fed55 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -858,7 +858,10 @@ def execute_model( # Update with the actual token ids sampled_token_ids = sampled_token_ids.tolist() for i, req_state, seq_len, gen_len in request_seq_lens: - token_ids = sampled_token_ids[i] + token_ids = [ + x for x 
in sampled_token_ids[i] if x != INVALID_TOKEN_ID + ] + sampled_token_ids[i] = token_ids for j, token_id in enumerate(token_ids): self.input_batch.token_ids_cpu[i, seq_len - gen_len + j + 1] = token_id From 18fba42050b3fbd9accb9d2b47e9d6a5c742e546 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 16:49:17 -0800 Subject: [PATCH 30/75] type --- vllm/v1/outputs.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 48a50da88d6d..890b74501df2 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -32,8 +32,11 @@ class ModelRunnerOutput: # req_id -> index req_id_to_index: Dict[str, int] - # num_reqs x [max_num_generated_tokens] - sampled_token_ids: torch.Tensor + # num_reqs x num_generated_tokens + # num_generated_tokens is the number of tokens + # generated in the current step. It can be different for + # each request. + sampled_token_ids: List[List[int]] # [num_reqs, max_num_logprobs + 1] logprob_token_ids_cpu: Optional[torch.Tensor] From ba1d0fdd84fc4ba14f3d7310d8163bc6444635d9 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 17:39:54 -0800 Subject: [PATCH 31/75] prefix caching + sd --- vllm/v1/core/kv_cache_manager.py | 42 +++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index de349ec12099..972b4e6f6a4b 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -122,7 +122,8 @@ def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[List[KVCacheBlock]] = None + new_computed_blocks: Optional[List[KVCacheBlock]] = None, + max_speculative_tokens: Optional[int] = None, ) -> Optional[List[KVCacheBlock]]: """Add slots for a request with new tokens to append. @@ -212,22 +213,39 @@ def allocate_slots( if not self.enable_caching: return new_blocks - # NOTE(rickyx): We are assuming the `num_tokens` are actual - # tokens rather than lookahead slots (e.g. for speculative decoding). - # TODO(rickyx): When supporting speculative decoding, we will need to - # differentiate between them so that we can know how many blocks are - # full after appending the actual tokens. - num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - num_computed_full_blocks = num_computed_tokens // self.block_size - new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] + # We subtract max_speculative_tokens from num_computed_tokens to + # get the least number of tokens in the KV cache. + # All tokens before this number are guaranteed to be valid and + # stored in the cache. + max_speculative_tokens = max_speculative_tokens or 0 + min_num_last_step_computed_tokens = request.num_computed_tokens \ + - max_speculative_tokens + min_num_last_step_computed_full_blocks = \ + min_num_last_step_computed_tokens // self.block_size + + # Calculate the total number of complete blocks needed after appending + # new valid tokens. + # num_tokens_wo_spec_tokens = 1 represents a single new token. + # Here, speculated tokens generated in the last step are counted in + # request.num_computed_tokens. Specualted tokens in the current step + # are not counted and cached because they are not verified/accepted it. 
+ num_tokens_wo_spec_tokens = 1 + num_full_blocks_after_append = ( + request.num_computed_tokens + + num_tokens_wo_spec_tokens) // self.block_size + + new_full_blocks = req_blocks[min_num_last_step_computed_full_blocks: + num_full_blocks_after_append] if new_full_blocks: self._cache_full_blocks( request=request, - blk_start_idx=num_computed_full_blocks, + blk_start_idx=min_num_last_step_computed_full_blocks, # The new full blocks are the full blocks that are not computed. full_blocks=new_full_blocks, - prev_block=(req_blocks[num_computed_full_blocks - 1] - if num_computed_full_blocks > 0 else None)) + prev_block=(req_blocks[min_num_last_step_computed_full_blocks - + 1] + if min_num_last_step_computed_full_blocks > 0 else + None)) return new_blocks From fc699536be57c2fdc9ee8bbcfb73013a54f64598 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 18:14:49 -0800 Subject: [PATCH 32/75] pass in max_spec_num --- vllm/v1/core/kv_cache_manager.py | 33 ++++++++++++++++++-------------- vllm/v1/core/scheduler.py | 6 +++++- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 972b4e6f6a4b..5d9ff4b838f3 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -213,23 +213,28 @@ def allocate_slots( if not self.enable_caching: return new_blocks - # We subtract max_speculative_tokens from num_computed_tokens to - # get the least number of tokens in the KV cache. - # All tokens before this number are guaranteed to be valid and - # stored in the cache. - max_speculative_tokens = max_speculative_tokens or 0 - min_num_last_step_computed_tokens = request.num_computed_tokens \ + # Calculate the total number of complete blocks needed after appending + # new valid tokens. + if max_speculative_tokens is None: # prefill phase + num_tokens_wo_spec_tokens = num_tokens + min_num_last_step_computed_tokens = request.num_computed_tokens + else: # decoding phase + assert max_speculative_tokens >= 0 + # We subtract max_speculative_tokens from num_computed_tokens to + # get the least number of tokens in the KV cache. + # All tokens before this number are guaranteed to be valid and + # stored in the cache. + min_num_last_step_computed_tokens = request.num_computed_tokens \ - max_speculative_tokens + # num_tokens_wo_spec_tokens = 1 represents a single new token. + # Here, speculated tokens generated in the last step are counted in + # request.num_computed_tokens. Specualted tokens in the current step + # are not counted and cached because they are not verified/accepted + # it. + num_tokens_wo_spec_tokens = 1 + min_num_last_step_computed_full_blocks = \ min_num_last_step_computed_tokens // self.block_size - - # Calculate the total number of complete blocks needed after appending - # new valid tokens. - # num_tokens_wo_spec_tokens = 1 represents a single new token. - # Here, speculated tokens generated in the last step are counted in - # request.num_computed_tokens. Specualted tokens in the current step - # are not counted and cached because they are not verified/accepted it. 
- num_tokens_wo_spec_tokens = 1 num_full_blocks_after_append = ( request.num_computed_tokens + num_tokens_wo_spec_tokens) // self.block_size diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f8634cda6ac6..31c87ee0da8a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -131,6 +131,8 @@ def schedule(self) -> "SchedulerOutput": # TODO(woosuk): Remove this constraint after refactoring model runner. has_partial_request = False req_index = 0 + spec_lens = [len(req.spec_token_ids) for req in self.running] + max_speculative_tokens = max(spec_lens) if spec_lens else 0 while req_index < len(self.running): # Only the last request in the RUNNING queue can be "partial". assert not has_partial_request @@ -151,7 +153,9 @@ def schedule(self) -> "SchedulerOutput": while True: new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens) + request, + num_new_tokens, + max_speculative_tokens=max_speculative_tokens) if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. From 970a91a382c3e262582fa4d98b8fbe1b26509b8e Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 23:01:56 -0800 Subject: [PATCH 33/75] fix block calcaulation --- vllm/v1/core/kv_cache_manager.py | 36 +++++++++++++++----------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 5d9ff4b838f3..84dae980b0b5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -215,32 +215,30 @@ def allocate_slots( # Calculate the total number of complete blocks needed after appending # new valid tokens. - if max_speculative_tokens is None: # prefill phase - num_tokens_wo_spec_tokens = num_tokens - min_num_last_step_computed_tokens = request.num_computed_tokens - else: # decoding phase - assert max_speculative_tokens >= 0 - # We subtract max_speculative_tokens from num_computed_tokens to - # get the least number of tokens in the KV cache. - # All tokens before this number are guaranteed to be valid and - # stored in the cache. - min_num_last_step_computed_tokens = request.num_computed_tokens \ - - max_speculative_tokens - # num_tokens_wo_spec_tokens = 1 represents a single new token. - # Here, speculated tokens generated in the last step are counted in - # request.num_computed_tokens. Specualted tokens in the current step - # are not counted and cached because they are not verified/accepted - # it. - num_tokens_wo_spec_tokens = 1 + max_speculative_tokens = max_speculative_tokens or 0 + + # We subtract max_speculative_tokens from num_computed_tokens to + # get the least number of tokens in the KV cache. + # All tokens before this number are guaranteed to be valid and + # stored in the cache. + min_num_last_step_computed_tokens = num_computed_tokens \ + - max_speculative_tokens + # Speculated tokens generated in the last step are counted in + # request.num_computed_tokens. Specualted tokens in the current step + # are not counted and cached because they are not verified/accepted + # yet. 
+ num_tokens_wo_spec_tokens = num_tokens - max_speculative_tokens min_num_last_step_computed_full_blocks = \ min_num_last_step_computed_tokens // self.block_size num_full_blocks_after_append = ( - request.num_computed_tokens + - num_tokens_wo_spec_tokens) // self.block_size + num_computed_tokens + num_tokens_wo_spec_tokens) // self.block_size new_full_blocks = req_blocks[min_num_last_step_computed_full_blocks: num_full_blocks_after_append] + print(num_tokens, max_speculative_tokens, + min_num_last_step_computed_full_blocks, + num_full_blocks_after_append) if new_full_blocks: self._cache_full_blocks( request=request, From 7ecb66819435be9c544342543c6bc03741e43174 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 3 Feb 2025 23:02:38 -0800 Subject: [PATCH 34/75] minor --- vllm/v1/core/kv_cache_manager.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 84dae980b0b5..d7a271941e21 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -236,9 +236,7 @@ def allocate_slots( new_full_blocks = req_blocks[min_num_last_step_computed_full_blocks: num_full_blocks_after_append] - print(num_tokens, max_speculative_tokens, - min_num_last_step_computed_full_blocks, - num_full_blocks_after_append) + if new_full_blocks: self._cache_full_blocks( request=request, From acda9237c7a309ab73944de1d32cf284599492fb Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 4 Feb 2025 22:11:44 -0800 Subject: [PATCH 35/75] fix comments --- tests/v1/core/test_scheduler.py | 72 ++++++++++++++++++++++++++- tests/v1/e2e/test_basic_specdecode.py | 36 -------------- tests/v1/e2e/test_ngram_specdecode.py | 49 ++++++++++++++++++ vllm/v1/core/kv_cache_manager.py | 3 ++ vllm/v1/core/scheduler.py | 7 +-- vllm/v1/outputs.py | 2 +- vllm/v1/sample/sampler.py | 4 ++ vllm/v1/worker/gpu_input_batch.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 4 +- 9 files changed, 134 insertions(+), 45 deletions(-) delete mode 100644 tests/v1/e2e/test_basic_specdecode.py create mode 100644 tests/v1/e2e/test_ngram_specdecode.py diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 8eb08f3e842c..b98c51d9bbc4 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from typing import List, Optional +import torch + from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -8,6 +10,8 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +EOS_TOKEN_ID = 50256 + def create_scheduler( model: str = "facebook/opt-125m", @@ -45,8 +49,12 @@ def create_requests( num_requests: int, num_tokens: int = 10, mm_positions: Optional[List[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[List[int]] = None, ): - sampling_params = SamplingParams() + sampling_params = SamplingParams(ignore_eos=True, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids) requests = [] for i in range(num_requests): if mm_positions is not None: @@ -194,7 +202,7 @@ def test_schedule_partial_requests(): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=[0] * len(requests), + sampled_token_ids=torch.tensor([[0]] * len(requests)), logprob_token_ids_cpu=None, 
logprobs_cpu=None, ) @@ -212,3 +220,63 @@ def test_schedule_partial_requests(): assert output.num_scheduled_tokens[requests[0].request_id] == 1 assert output.num_scheduled_tokens[requests[1].request_id] == 700 assert requests[2].request_id not in output.num_scheduled_tokens + + +def test_multiple_stop_tokens(): + """Test with stop when generating multiple tokens""" + scheduler = create_scheduler() + # Nonstop case + request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([4, 5, 6, 7, 8]) + result = scheduler._check_stop(request) + assert result is False + + # EOS token is generated in the beginning of the output tokens + request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([EOS_TOKEN_ID, 5, EOS_TOKEN_ID, 7, 43, 5]) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_STOPPED + assert request.request_id in scheduler.finished_req_ids + assert len(request.output_token_ids) == 1 + assert list(request.output_token_ids) == [EOS_TOKEN_ID] + + # EOS token is generated in the middle of the output tokens + request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([1, 2, 3, 4, 5, EOS_TOKEN_ID, 7, 43, 5]) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_STOPPED + assert request.request_id in scheduler.finished_req_ids + assert len(request.output_token_ids) == 6 + assert list(request.output_token_ids) == [1, 2, 3, 4, 5, EOS_TOKEN_ID] + + # Stop token, 43 is one of the stop tokens + request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + request.append_output_token_ids([4, 5, 43, 7, 43, 5]) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_STOPPED + assert request.stop_reason == 43 + assert request.request_id in scheduler.finished_req_ids + # Should be cropped at the first stop token + assert len(request.output_token_ids) == 3 + assert list(request.output_token_ids) == [4, 5, 43] + + # Max tokens, should be cropped when reaching the max tokens + max_tokens = 2 + request = create_requests(max_tokens=max_tokens, + stop_token_ids=[42, 43, 44]) + scheduler.requests[request.request_id] = request + output_token_ids = [4, 5, 43, 7, 43, 5] + request.append_output_token_ids(output_token_ids) + result = scheduler._check_stop(request) + assert result is True + assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED + assert request.request_id in scheduler.finished_req_ids + assert len(request.output_token_ids) == max_tokens + assert list(request.output_token_ids) == output_token_ids[:max_tokens] diff --git a/tests/v1/e2e/test_basic_specdecode.py b/tests/v1/e2e/test_basic_specdecode.py deleted file mode 100644 index 2e3f19afb99d..000000000000 --- a/tests/v1/e2e/test_basic_specdecode.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from vllm import LLM, SamplingParams - -prompts = [ - "Can you repeat the sentence ten times, this is a sentence?", - "This is a basic spec decode test", -] -# Only support greedy for now -sampling_params = SamplingParams(temperature=0) - - -def test_basic_specdecode(monkeypatch): - ''' - Compare the 
outputs of a original LLM and a speculative LLM - should be the same. - ''' - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - model = "meta-llama/Meta-Llama-3-8B-Instruct" - - ref_llm = LLM(model=model) - ref_outputs = ref_llm.generate(prompts, sampling_params) - del ref_llm - # print(ref_outputs.outputs[0].text) - - spec_llm = LLM(model=model, - speculative_model='[ngram]', - ngram_prompt_lookup_max=5, - ngram_prompt_lookup_min=3, - num_speculative_tokens=3) - spec_outputs = spec_llm.generate(prompts, sampling_params) - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - assert ref_output.outputs[0].text == spec_output.outputs[0].text, \ - (f"ref_output: {ref_output.outputs[0].text}," - f"spec_output: {spec_output.outputs[0].text}") - del spec_llm diff --git a/tests/v1/e2e/test_ngram_specdecode.py b/tests/v1/e2e/test_ngram_specdecode.py new file mode 100644 index 000000000000..a4750eb62075 --- /dev/null +++ b/tests/v1/e2e/test_ngram_specdecode.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from vllm import LLM, SamplingParams + + +@pytest.fixture +def test_prompts(): + return [ + "Can you repeat the sentence ten times, this is a sentence?", + "This is a basic spec decode test", + ] + + +@pytest.fixture +def sampling_config(): + # Only support greedy for now + return SamplingParams(temperature=0, max_tokens=100, ignore_eos=False) + + +@pytest.fixture +def model_name(): + return "meta-llama/Meta-Llama-3-8B-Instruct" + + +def test_ngram_correctness(monkeypatch, test_prompts, sampling_config, + model_name): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using ngram speculative decoding. + ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + ref_llm = LLM(model=model_name) + ref_outputs = ref_llm.generate(test_prompts, sampling_config) + del ref_llm + + spec_llm = LLM(model=model_name, + speculative_model='[ngram]', + ngram_prompt_lookup_max=5, + ngram_prompt_lookup_min=3, + num_speculative_tokens=3) + spec_outputs = spec_llm.generate(test_prompts, sampling_config) + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + assert ref_output.outputs[0].text == spec_output.outputs[0].text, \ + (f"ref_output: {ref_output.outputs[0].text}," + f"spec_output: {spec_output.outputs[0].text}") + del spec_llm diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index d7a271941e21..c17adaaf8e30 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -133,6 +133,9 @@ def allocate_slots( not include the tokens that have already been computed. new_computed_blocks: A list of new computed blocks just hitting the prefix caching. + max_speculative_tokens: The maximum number of speculative tokens, + used to calculate the minimum number of full blocks that are + cached. Blocks layout: ----------------------------------------------------------------------- diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 5ec06ed585bd..3d4b8bb8709a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -493,7 +493,7 @@ def update_from_output( # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. 
- stopped = self._check_stop(request) + stopped = self._maybe_stop_and_crop(request) num_new_tokens = request.num_tokens - num_tokens_before_step if stopped: self._free_request(request) @@ -518,12 +518,13 @@ def update_from_output( scheduler_stats=self.make_stats(), ) - def _check_stop(self, request: Request) -> bool: + def _maybe_stop_and_crop(self, request: Request) -> bool: """Check if the request should be stopped. The function should handle both single token generation or multiple token generation (e.g., spec decode) per step. - This function will crop requests to the given number of tokens. + This function will crop requests because the request is stopped + in the middle of the generation. When cropping, we do not need to update the input batch because it will be updated in the next execute_model call's _update_states method, where the request data is aligned diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 890b74501df2..b2110c0e2d84 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -35,7 +35,7 @@ class ModelRunnerOutput: # num_reqs x num_generated_tokens # num_generated_tokens is the number of tokens # generated in the current step. It can be different for - # each request. + # each request due to speculative/jump decoding. sampled_token_ids: List[List[int]] # [num_reqs, max_num_logprobs + 1] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 2543864e0550..2bc41624a65b 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -28,6 +28,10 @@ def forward( sampling_metadata: SamplingMetadata, ) -> SamplerOutput: if sampling_metadata.rejection_sampling: + needs_logprobs = sampling_metadata.max_num_logprobs > 0 + if needs_logprobs: + raise NotImplementedError( + "Rejection sampling does not support logprobs.") return self.rejection_sampler.sample( logits, sampling_metadata, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 727649f67fa6..ae7d28318074 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -369,7 +369,7 @@ def make_sampling_metadata( # statistics. output_token_ids.append(req_id_output_token_ids[req_id]) req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, None) - if req_spec_token_ids is not None: + if req_spec_token_ids is not None and len(req_spec_token_ids) > 0: spec_token_ids.append(req_spec_token_ids) # If any of the requests require speculative decoding, set the # flag to True. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7290fb2deac5..e98c14bfaaa8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -851,8 +851,8 @@ def execute_model( break assert req_id is not None req_state = self.requests[req_id] - seq_len = req_state.num_computed_tokens + \ - scheduler_output.num_scheduled_tokens[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) if seq_len >= req_state.num_tokens: # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. 
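
Since the hunks above note that spec decode currently supports only greedy decoding, the verification step reduces to comparing draft tokens against the target model's greedy picks. A simplified, hypothetical illustration of that idea (the actual logic lives in vllm/v1/sample/rejection_sampler.py; the helper below is illustrative only):

from typing import List

def greedy_verify(draft_tokens: List[int], target_argmax: List[int]) -> List[int]:
    # target_argmax[j] is the target model's greedy token after consuming the
    # last accepted token plus draft_tokens[:j]; it has len(draft_tokens) + 1
    # entries, the final one being the bonus token after all drafts.
    accepted: List[int] = []
    for draft, target in zip(draft_tokens, target_argmax):
        if draft != target:
            # First mismatch: emit the target's own token and stop.
            accepted.append(target)
            return accepted
        accepted.append(draft)
    # Every draft matched, so the bonus token is emitted as well.
    accepted.append(target_argmax[len(draft_tokens)])
    return accepted

print(greedy_verify([7, 8, 9], [7, 8, 5, 11]))  # [7, 8, 5]
print(greedy_verify([7, 8, 9], [7, 8, 9, 11]))  # [7, 8, 9, 11]
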
From 2006c7593f5acb0c0f987944354565c3678f06e2 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 4 Feb 2025 22:14:42 -0800 Subject: [PATCH 36/75] fix test --- tests/v1/core/test_scheduler.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index b98c51d9bbc4..7c5c2d9af586 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -226,14 +226,18 @@ def test_multiple_stop_tokens(): """Test with stop when generating multiple tokens""" scheduler = create_scheduler() # Nonstop case - request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + request = create_requests(num_request=1, + max_tokens=100, + stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request request.append_output_token_ids([4, 5, 6, 7, 8]) result = scheduler._check_stop(request) assert result is False # EOS token is generated in the beginning of the output tokens - request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + request = create_requests(num_requests=1, + max_tokens=100, + stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request request.append_output_token_ids([EOS_TOKEN_ID, 5, EOS_TOKEN_ID, 7, 43, 5]) result = scheduler._check_stop(request) @@ -244,7 +248,9 @@ def test_multiple_stop_tokens(): assert list(request.output_token_ids) == [EOS_TOKEN_ID] # EOS token is generated in the middle of the output tokens - request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + request = create_requests(num_requests=1, + max_tokens=100, + stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request request.append_output_token_ids([1, 2, 3, 4, 5, EOS_TOKEN_ID, 7, 43, 5]) result = scheduler._check_stop(request) @@ -255,7 +261,9 @@ def test_multiple_stop_tokens(): assert list(request.output_token_ids) == [1, 2, 3, 4, 5, EOS_TOKEN_ID] # Stop token, 43 is one of the stop tokens - request = create_requests(max_tokens=100, stop_token_ids=[42, 43, 44]) + request = create_requests(num_requests=1, + max_tokens=100, + stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request request.append_output_token_ids([4, 5, 43, 7, 43, 5]) result = scheduler._check_stop(request) @@ -269,7 +277,8 @@ def test_multiple_stop_tokens(): # Max tokens, should be cropped when reaching the max tokens max_tokens = 2 - request = create_requests(max_tokens=max_tokens, + request = create_requests(num_requests=1, + max_tokens=max_tokens, stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request output_token_ids = [4, 5, 43, 7, 43, 5] From 2ad4f3989000aa10cd3342ac650715def09075fd Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 4 Feb 2025 22:15:23 -0800 Subject: [PATCH 37/75] fix test --- tests/v1/core/test_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 7c5c2d9af586..a118b3f009da 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -226,7 +226,7 @@ def test_multiple_stop_tokens(): """Test with stop when generating multiple tokens""" scheduler = create_scheduler() # Nonstop case - request = create_requests(num_request=1, + request = create_requests(num_requests=1, max_tokens=100, stop_token_ids=[42, 43, 44]) scheduler.requests[request.request_id] = request From faafcb63085073af42198b9f37bb1eabc5bf9d37 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: 
Tue, 4 Feb 2025 22:22:38 -0800 Subject: [PATCH 38/75] fix test --- tests/v1/core/test_scheduler.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index a118b3f009da..d711b60fe1f9 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -52,7 +52,7 @@ def create_requests( max_tokens: int = 16, stop_token_ids: Optional[List[int]] = None, ): - sampling_params = SamplingParams(ignore_eos=True, + sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids) requests = [] @@ -71,7 +71,7 @@ def create_requests( multi_modal_inputs=mm_inputs, multi_modal_placeholders=mm_position, multi_modal_hashes=None, - eos_token_id=None, + eos_token_id=EOS_TOKEN_ID, arrival_time=0, ) requests.append(request) @@ -228,49 +228,46 @@ def test_multiple_stop_tokens(): # Nonstop case request = create_requests(num_requests=1, max_tokens=100, - stop_token_ids=[42, 43, 44]) + stop_token_ids=[42, 43, 44])[0] scheduler.requests[request.request_id] = request request.append_output_token_ids([4, 5, 6, 7, 8]) - result = scheduler._check_stop(request) + result = scheduler._maybe_stop_and_crop(request) assert result is False # EOS token is generated in the beginning of the output tokens request = create_requests(num_requests=1, max_tokens=100, - stop_token_ids=[42, 43, 44]) + stop_token_ids=[42, 43, 44])[0] scheduler.requests[request.request_id] = request request.append_output_token_ids([EOS_TOKEN_ID, 5, EOS_TOKEN_ID, 7, 43, 5]) - result = scheduler._check_stop(request) + result = scheduler._maybe_stop_and_crop(request) assert result is True assert request.status == RequestStatus.FINISHED_STOPPED - assert request.request_id in scheduler.finished_req_ids assert len(request.output_token_ids) == 1 assert list(request.output_token_ids) == [EOS_TOKEN_ID] # EOS token is generated in the middle of the output tokens request = create_requests(num_requests=1, max_tokens=100, - stop_token_ids=[42, 43, 44]) + stop_token_ids=[42, 43, 44])[0] scheduler.requests[request.request_id] = request request.append_output_token_ids([1, 2, 3, 4, 5, EOS_TOKEN_ID, 7, 43, 5]) - result = scheduler._check_stop(request) + result = scheduler._maybe_stop_and_crop(request) assert result is True assert request.status == RequestStatus.FINISHED_STOPPED - assert request.request_id in scheduler.finished_req_ids assert len(request.output_token_ids) == 6 assert list(request.output_token_ids) == [1, 2, 3, 4, 5, EOS_TOKEN_ID] # Stop token, 43 is one of the stop tokens request = create_requests(num_requests=1, max_tokens=100, - stop_token_ids=[42, 43, 44]) + stop_token_ids=[42, 43, 44])[0] scheduler.requests[request.request_id] = request request.append_output_token_ids([4, 5, 43, 7, 43, 5]) - result = scheduler._check_stop(request) + result = scheduler._maybe_stop_and_crop(request) assert result is True assert request.status == RequestStatus.FINISHED_STOPPED assert request.stop_reason == 43 - assert request.request_id in scheduler.finished_req_ids # Should be cropped at the first stop token assert len(request.output_token_ids) == 3 assert list(request.output_token_ids) == [4, 5, 43] @@ -279,13 +276,12 @@ def test_multiple_stop_tokens(): max_tokens = 2 request = create_requests(num_requests=1, max_tokens=max_tokens, - stop_token_ids=[42, 43, 44]) + stop_token_ids=[42, 43, 44])[0] scheduler.requests[request.request_id] = request output_token_ids = [4, 5, 43, 7, 43, 5] 
request.append_output_token_ids(output_token_ids) - result = scheduler._check_stop(request) + result = scheduler._maybe_stop_and_crop(request) assert result is True assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED - assert request.request_id in scheduler.finished_req_ids assert len(request.output_token_ids) == max_tokens assert list(request.output_token_ids) == output_token_ids[:max_tokens] From 038e20339cc70476a5c90b4bac6ca3e41e85fe13 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 4 Feb 2025 22:31:09 -0800 Subject: [PATCH 39/75] merge test --- tests/v1/core/test_stop_checking.py | 105 ---------------------------- 1 file changed, 105 deletions(-) delete mode 100644 tests/v1/core/test_stop_checking.py diff --git a/tests/v1/core/test_stop_checking.py b/tests/v1/core/test_stop_checking.py deleted file mode 100644 index 192ca276edf2..000000000000 --- a/tests/v1/core/test_stop_checking.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from typing import List - -import pytest - -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig -from vllm.v1.core.scheduler import Scheduler -from vllm.v1.request import Request, RequestStatus, SamplingParams - -EOS_TOKEN_ID = 50256 - - -@pytest.fixture -def scheduler(): - cache_config = CacheConfig(block_size=16, - gpu_memory_utilization=0.9, - swap_space=0.1, - cache_dtype="auto") - cache_config.num_gpu_blocks = 100 - return Scheduler(scheduler_config=SchedulerConfig(), - model_config=ModelConfig(model="facebook/opt-125m", - task="auto", - tokenizer="test_tokenizer", - tokenizer_mode="auto", - trust_remote_code=False, - dtype="float16", - seed=42), - cache_config=cache_config, - lora_config=None) - - -def _create_test_request(request_id: str, max_tokens: int, - stop_token_ids: List[int]) -> Request: - return Request(request_id=request_id, - prompt="test prompt", - prompt_token_ids=[1, 2, 3], - multi_modal_inputs=None, - multi_modal_hashes=None, - multi_modal_placeholders=None, - sampling_params=SamplingParams( - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - ignore_eos=False), - eos_token_id=EOS_TOKEN_ID, - arrival_time=0.0) - - -def test_multiple_stop_tokens(scheduler): - """Test with stop when generating multiple tokens""" - # Nonstop case - request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) - scheduler.requests[request.request_id] = request - request.append_output_token_ids([4, 5, 6, 7, 8]) - result = scheduler._check_stop(request) - assert result is False - - # EOS token is generated in the beginning of the output tokens - request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) - scheduler.requests[request.request_id] = request - request.append_output_token_ids([EOS_TOKEN_ID, 5, EOS_TOKEN_ID, 7, 43, 5]) - result = scheduler._check_stop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_STOPPED - assert request.request_id in scheduler.finished_req_ids - assert len(request.output_token_ids) == 1 - assert list(request.output_token_ids) == [EOS_TOKEN_ID] - - # EOS token is generated in the middle of the output tokens - request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) - scheduler.requests[request.request_id] = request - request.append_output_token_ids([1, 2, 3, 4, 5, EOS_TOKEN_ID, 7, 43, 5]) - result = scheduler._check_stop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_STOPPED - assert request.request_id in scheduler.finished_req_ids - assert 
len(request.output_token_ids) == 6 - assert list(request.output_token_ids) == [1, 2, 3, 4, 5, EOS_TOKEN_ID] - - # Stop token, 43 is one of the stop tokens - request = _create_test_request("test1", 100, stop_token_ids=[42, 43, 44]) - scheduler.requests[request.request_id] = request - request.append_output_token_ids([4, 5, 43, 7, 43, 5]) - result = scheduler._check_stop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_STOPPED - assert request.stop_reason == 43 - assert request.request_id in scheduler.finished_req_ids - # Should be cropped at the first stop token - assert len(request.output_token_ids) == 3 - assert list(request.output_token_ids) == [4, 5, 43] - - # Max tokens, should be cropped when reaching the max tokens - max_tokens = 2 - request = _create_test_request("test2", - max_tokens, - stop_token_ids=[42, 43, 44]) - scheduler.requests[request.request_id] = request - output_token_ids = [4, 5, 43, 7, 43, 5] - request.append_output_token_ids(output_token_ids) - result = scheduler._check_stop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED - assert request.request_id in scheduler.finished_req_ids - assert len(request.output_token_ids) == max_tokens - assert list(request.output_token_ids) == output_token_ids[:max_tokens] From 4dc0f878ba0f80d18ec33c7929c4e450911adae5 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 7 Feb 2025 09:45:47 -0800 Subject: [PATCH 40/75] stop checking --- vllm/v1/core/scheduler.py | 83 ++++++++++----------------- vllm/v1/outputs.py | 2 +- vllm/v1/request.py | 4 +- vllm/v1/sample/metadata.py | 4 +- vllm/v1/sample/sampler.py | 2 +- vllm/v1/spec_decode/ngram_proposer.py | 8 ++- vllm/v1/worker/gpu_input_batch.py | 10 ++-- vllm/v1/worker/gpu_model_runner.py | 22 ++++--- 8 files changed, 55 insertions(+), 80 deletions(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 3d4b8bb8709a..81600d88f764 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -16,7 +16,6 @@ from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -from vllm.v1.utils import ConstantList if TYPE_CHECKING: from vllm.multimodal import MultiModalKwargs @@ -33,7 +32,7 @@ def __init__( model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig] = None, + speculative_config: Optional[SpeculativeConfig], ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -121,7 +120,7 @@ def schedule(self) -> "SchedulerOutput": # Spec Decode-related. spec_decode = False - scheduled_spec_decode_tokens: Dict[str, ConstantList[int]] = {} + scheduled_spec_decode_tokens: Dict[str, List[int]] = {} # First, schedule the RUNNING requests. 
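To make the num_tokens_with_spec accounting described in the comment above concrete, here is a small illustration (the numbers and variable names are made up for this sketch, not taken from the patch):

    # Illustrative accounting for one running request (numbers made up):
    prompt_len = 10          # prompt tokens, KV already computed (prefill done)
    num_output_tokens = 1    # one token sampled so far
    num_spec_tokens = 2      # draft tokens proposed for this step

    num_computed_tokens = prompt_len                                         # 10
    num_tokens_with_spec = prompt_len + num_output_tokens + num_spec_tokens  # 13

    # The scheduler asks the model to process the last sampled token plus the
    # two draft tokens in a single forward pass, so they can all be verified:
    num_new_tokens = num_tokens_with_spec - num_computed_tokens
    assert num_new_tokens == 3

In other words, the verification step feeds the last sampled token plus all draft tokens through the model in one pass, and num_computed_tokens then catches up by however many of them are accepted.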
req_index = 0 @@ -129,8 +128,8 @@ def schedule(self) -> "SchedulerOutput": max_speculative_tokens = max(spec_lens) if spec_lens else 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = request.num_tokens_with_spec \ - - request.num_computed_tokens + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -486,22 +485,27 @@ def update_from_output( request, input_id) if request.num_computed_tokens >= request.num_tokens: + # We assume all spec tokens are verified + # if we perform speculative decoding for this request. + # Therefore, we can clear all spec tokens after + # the generation step. request.clear_spec_tokens() - num_tokens_before_step = request.num_tokens - request.append_output_token_ids(token_ids) - # TODO: Update the KV cache manager for prefix caching. - - # Check for stop and update request state. - # This must be called before we make the EngineCoreOutput. - stopped = self._maybe_stop_and_crop(request) - num_new_tokens = request.num_tokens - num_tokens_before_step - if stopped: - self._free_request(request) + + new_token_ids = [] + for output_token_id in token_ids: + request.append_output_token_ids(token_ids) + new_token_ids.append(output_token_id) + + stopped = self._check_stop(request, output_token_id) + # This must be called before we make the EngineCoreOutput. + if stopped: + self._free_request(request) + break # Add EngineCoreOutput for this Request. output = EngineCoreOutput( request_id=req_id, - new_token_ids=request.output_token_ids[-num_new_tokens:], + new_token_ids=new_token_ids, finished=request.is_finished(), finish_reason=request.get_finished_reason(), stop_reason=request.stop_reason) @@ -518,47 +522,22 @@ def update_from_output( scheduler_stats=self.make_stats(), ) - def _maybe_stop_and_crop(self, request: Request) -> bool: - """Check if the request should be stopped. - The function should handle both single token generation or - multiple token generation (e.g., spec decode) per step. - - This function will crop requests because the request is stopped - in the middle of the generation. - When cropping, we do not need to update the input batch because - it will be updated in the next execute_model call's - _update_states method, where the request data is aligned - with the data in the persistent batch. 
- """ + def _check_stop(self, request: Request, last_token_id: int) -> bool: if (request.num_tokens >= self.max_model_len or request.num_output_tokens >= request.max_tokens): request.status = RequestStatus.FINISHED_LENGTH_CAPPED - num_total_token = min( - self.max_model_len, request.num_tokens, - request.max_tokens + request.num_prompt_tokens, - request.num_output_tokens + request.num_prompt_tokens) - request.crop(num_total_token) return True sampling_params = request.sampling_params - if not sampling_params.ignore_eos: - assert request.eos_token_id is not None - if request.eos_token_id in request.output_token_ids: - request.status = RequestStatus.FINISHED_STOPPED - num_total_token = request.num_prompt_tokens + \ - request.output_token_ids.index(request.eos_token_id) + 1 - request.crop(num_total_token) - return True - - output_token_ids = set(request.output_token_ids) - for stop_token_id in sampling_params.stop_token_ids: - if stop_token_id in output_token_ids: - request.status = RequestStatus.FINISHED_STOPPED - request.stop_reason = stop_token_id - num_total_token = request.num_prompt_tokens + \ - request.output_token_ids.index(stop_token_id) + 1 - request.crop(num_total_token) - return True + if (not sampling_params.ignore_eos + and last_token_id == request.eos_token_id): + request.status = RequestStatus.FINISHED_STOPPED + return True + + if last_token_id in (sampling_params.stop_token_ids or ()): + request.status = RequestStatus.FINISHED_STOPPED + request.stop_reason = last_token_id + return True return False def add_request(self, request: Request) -> None: @@ -688,7 +667,7 @@ class SchedulerOutput: total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] use_spec_decode: bool - scheduled_spec_decode_tokens: Dict[str, ConstantList[int]] + scheduled_spec_decode_tokens: Dict[str, List[int]] num_common_prefix_blocks: int finished_req_ids: Set[str] diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index b2110c0e2d84..45b203f8c79e 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -9,7 +9,7 @@ @dataclass class SamplerOutput: - # num_reqs x [max_num_generated_tokens] + # [num_reqs, max_num_generated_tokens] sampled_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3713fbe86ad0..e789c16f6448 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -111,8 +111,8 @@ def clear_spec_tokens(self) -> None: self._spec_token_ids = [] @property - def spec_token_ids(self) -> ConstantList[int]: - return ConstantList(self._spec_token_ids) + def spec_token_ids(self) -> List[int]: + return self._spec_token_ids @property def num_tokens(self) -> int: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index eb9f95ebb611..1d0b5e134927 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -5,8 +5,6 @@ import torch -from vllm.v1.utils import ConstantList - @dataclass class SamplingMetadata: @@ -15,7 +13,7 @@ class SamplingMetadata: all_greedy: bool all_random: bool rejection_sampling: bool - spec_token_ids: List[ConstantList[int]] + spec_token_ids: List[List[int]] top_p: torch.Tensor top_k: torch.Tensor diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 2bc41624a65b..7621cf569876 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -27,8 +27,8 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + needs_logprobs = sampling_metadata.max_num_logprobs > 0 if 
sampling_metadata.rejection_sampling: - needs_logprobs = sampling_metadata.max_num_logprobs > 0 if needs_logprobs: raise NotImplementedError( "Rejection sampling does not support logprobs.") diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 4a7e34cf865f..8eee99506b1f 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -41,7 +41,8 @@ def propose(self, context_token_ids: ConstantList[int], n: int, # improve the efficiency return self._find_subarray_kmp(context_token_ids, n, k) - def _kmp_lps_array(self, pattern: List[int]) -> List[int]: + @staticmethod + def _kmp_lps_array(pattern: List[int]) -> List[int]: """ Build the lps (longest proper prefix which is also suffix) array for the pattern. @@ -64,14 +65,15 @@ def _kmp_lps_array(self, pattern: List[int]) -> List[int]: return lps - def _find_subarray_kmp(self, context_token_ids: ConstantList[int], n: int, + @staticmethod + def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int, k: int) -> Optional[List[int]]: context_len = len(context_token_ids) assert n > 0 pattern = context_token_ids[-n:] # Precompute lps array for Y - lps = self._kmp_lps_array(pattern) + lps = NgramProposer._kmp_lps_array(pattern) i = 0 j = 0 diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ae7d28318074..040f321f85ca 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,7 +11,6 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.utils import ConstantList from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: @@ -325,8 +324,7 @@ def condense(self, empty_req_indices: List[int]) -> None: def make_sampling_metadata( self, req_id_output_token_ids: Dict[str, List[int]], - req_id_to_spec_token_ids: Optional[Dict[str, - ConstantList[int]]] = None, + req_id_to_spec_token_ids: Optional[Dict[str, List[int]]] = None, skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: @@ -355,7 +353,7 @@ def make_sampling_metadata( self.prompt_token_ids = self._make_prompt_token_ids_tensor() output_token_ids: List[List[int]] = [] - spec_token_ids: List[ConstantList[int]] = [] + spec_token_ids: List[List[int]] = [] rejection_sampling = False req_id_to_spec_token_ids = req_id_to_spec_token_ids or {} for req_id in self.req_ids[:self.num_reqs]: @@ -368,8 +366,8 @@ def make_sampling_metadata( # TODO - Replace this with incremental update to output token # statistics. output_token_ids.append(req_id_output_token_ids[req_id]) - req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, None) - if req_spec_token_ids is not None and len(req_spec_token_ids) > 0: + req_spec_token_ids = req_id_to_spec_token_ids.get(req_id) + if req_spec_token_ids: spec_token_ids.append(req_spec_token_ids) # If any of the requests require speculative decoding, set the # flag to True. 
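For context on the NgramProposer changes above: _find_subarray_kmp searches the earlier context for an occurrence of the last n tokens and, when it finds one, proposes the tokens that followed that occurrence as draft tokens. A self-contained toy version of that idea (a linear scan instead of KMP; names are illustrative, not the vLLM helpers):

    from typing import List, Optional

    def propose_ngram(context: List[int], n: int, k: int) -> Optional[List[int]]:
        """Toy n-gram proposer: find an earlier occurrence of the trailing n
        tokens and speculate the (up to) k tokens that followed that match."""
        if n <= 0 or len(context) <= n:
            return None
        pattern = context[-n:]
        for start in range(len(context) - n):
            if context[start:start + n] == pattern:
                follow = context[start + n:start + n + k]
                return follow if follow else None
        return None

    # The trailing [7, 8] also appears at positions 1-2, so the two tokens that
    # followed it there become the draft tokens.
    print(propose_ngram([5, 7, 8, 9, 1, 7, 8], n=2, k=2))  # -> [9, 1]

The real code builds the KMP prefix-function (_kmp_lps_array) so the search stays linear in the context length even for long patterns.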
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e98c14bfaaa8..506372db971a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -32,7 +32,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID -from vllm.v1.utils import ConstantList, bind_kv_cache +from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: @@ -376,7 +376,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens = [] max_num_scheduled_tokens = 0 - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) @@ -424,16 +424,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ # Get spec decode logits indices. spec_query_end_loc = 0 spec_decode_logits_indices: List[int] = [] - for i, req_id in enumerate(self.input_batch.req_ids): - if i == num_reqs: - break + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] spec_query_end_loc += req_num_scheduled_tokens - spec_token_ids: ConstantList = scheduler_output.\ - scheduled_spec_decode_tokens.get(req_id, ConstantList([])) + spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) for j, spec_token_id in enumerate(spec_token_ids): # +1 here because the input for verification is # [last_output_token_id] + spec_token_ids @@ -598,10 +596,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ logits_indices = torch.tensor(spec_decode_logits_indices, device=self.device) else: - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 - # partial request in the batch. While we should not sample any - # token from this partial request, we do so for simplicity. - # We will ignore the sampled token from the partial request. + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices @@ -664,7 +662,7 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _prepare_sampling( self, batch_changed: bool, - req_id_spec_token_ids: Dict[str, ConstantList[int]], + req_id_spec_token_ids: Dict[str, List[int]], ) -> SamplingMetadata: # Create the sampling metadata. 
req_id_output_token_ids: Dict[str, List[int]] = \ From 54508d56a2b630bb22d18ecf94bc1aa0c48df280 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 7 Feb 2025 22:44:52 -0800 Subject: [PATCH 41/75] kv cache manager --- vllm/v1/core/kv_cache_manager.py | 50 ++++++++++++-------------------- vllm/v1/core/scheduler.py | 17 +++++------ vllm/v1/outputs.py | 3 ++ vllm/v1/request.py | 20 ------------- 4 files changed, 29 insertions(+), 61 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c17adaaf8e30..8ff5655e3df8 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -72,6 +72,11 @@ def __init__( self.req_to_blocks: DefaultDict[str, List[KVCacheBlock]] = defaultdict(list) + # {req_id: The number of cached blocks for this given request} + # This is used to track the number of cached blocks for each request, + # currently only used for speculative decoding. + self.cached_block_num: Dict[str, int] = defaultdict(int) + @property def usage(self) -> float: return 1.0 - (self.free_block_queue.num_free_blocks / @@ -122,8 +127,7 @@ def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[List[KVCacheBlock]] = None, - max_speculative_tokens: Optional[int] = None, + new_computed_blocks: Optional[List[KVCacheBlock]] = None ) -> Optional[List[KVCacheBlock]]: """Add slots for a request with new tokens to append. @@ -216,41 +220,23 @@ def allocate_slots( if not self.enable_caching: return new_blocks - # Calculate the total number of complete blocks needed after appending - # new valid tokens. - max_speculative_tokens = max_speculative_tokens or 0 - - # We subtract max_speculative_tokens from num_computed_tokens to - # get the least number of tokens in the KV cache. - # All tokens before this number are guaranteed to be valid and - # stored in the cache. - min_num_last_step_computed_tokens = num_computed_tokens \ - - max_speculative_tokens - # Speculated tokens generated in the last step are counted in - # request.num_computed_tokens. Specualted tokens in the current step - # are not counted and cached because they are not verified/accepted - # yet. - num_tokens_wo_spec_tokens = num_tokens - max_speculative_tokens - - min_num_last_step_computed_full_blocks = \ - min_num_last_step_computed_tokens // self.block_size - num_full_blocks_after_append = ( - num_computed_tokens + num_tokens_wo_spec_tokens) // self.block_size - - new_full_blocks = req_blocks[min_num_last_step_computed_full_blocks: - num_full_blocks_after_append] + num_cached_blocks = self.cached_block_num[ + request.request_id] // self.block_size + num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( + request.spec_token_ids)) // self.block_size + new_full_blocks = req_blocks[ + num_cached_blocks:num_full_blocks_after_append] if new_full_blocks: self._cache_full_blocks( request=request, - blk_start_idx=min_num_last_step_computed_full_blocks, + blk_start_idx=num_cached_blocks, # The new full blocks are the full blocks that are not computed. 
full_blocks=new_full_blocks, - prev_block=(req_blocks[min_num_last_step_computed_full_blocks - - 1] - if min_num_last_step_computed_full_blocks > 0 else - None)) - + prev_block=(req_blocks[num_cached_blocks - + 1] if num_cached_blocks > 0 else None)) + self.cached_block_num[ + request.request_id] = num_full_blocks_after_append return new_blocks def free(self, request: Request) -> None: @@ -274,6 +260,8 @@ def free(self, request: Request) -> None: if block.ref_cnt == 0: self.free_block_queue.append(block) + self.cached_block_num.pop(request.request_id) + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalid prefix caching after the weights are updated, diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 81600d88f764..98bada83993a 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -98,13 +98,14 @@ def __init__( def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. - # Each request just has the num_computed_tokens and num_tokens. - # num_tokens = len(prompt_token_ids) + len(output_token_ids) + - # len(spec_token_ids). + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). # At each step, the scheduler tries to assign tokens to the requests # so that each request's num_computed_tokens can catch up its - # num_tokens. This is general enough to cover chunked prefills, - # prefix caching, and the "jump decoding" optimization in the future. + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. scheduled_new_reqs: List[Request] = [] scheduled_resumed_reqs: List[Request] = [] @@ -124,8 +125,6 @@ def schedule(self) -> "SchedulerOutput": # First, schedule the RUNNING requests. req_index = 0 - spec_lens = [len(req.spec_token_ids) for req in self.running] - max_speculative_tokens = max(spec_lens) if spec_lens else 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] num_new_tokens = (request.num_tokens_with_spec - @@ -150,9 +149,7 @@ def schedule(self) -> "SchedulerOutput": while True: new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - max_speculative_tokens=max_speculative_tokens) + request, num_new_tokens) if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 45b203f8c79e..88d350d67563 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -10,6 +10,9 @@ class SamplerOutput: # [num_reqs, max_num_generated_tokens] + # Different requests can have different number of generated tokens. + # All requests are padded to max_num_generated_tokens. + # INVALID_TOKEN_ID is used for padding. 
sampled_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index e789c16f6448..a2eb56413b58 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -151,26 +151,6 @@ def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None: def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None: self._kv_block_hashes.append(block_hash) - def crop(self, num_total_token: int) -> None: - """Crops the token sequences to a specified total length while - preserving prompt tokens. - - Args: - num_total_token: The desired total number of tokens after cropping. - - Raises: - ValueError: If num_total_token is less than the number of prompt - tokens, as prompt tokens cannot be cropped. - """ - - if num_total_token < self.num_prompt_tokens: - raise ValueError("Cannot crop the prompt tokens.") - num_output_token = num_total_token - self.num_prompt_tokens - self._output_token_ids = self._output_token_ids[:num_output_token] - self._all_token_ids = self._all_token_ids[:num_total_token] - self.output_token_ids = ConstantList(self._output_token_ids) - self.all_token_ids = ConstantList(self._all_token_ids) - class RequestStatus(enum.IntEnum): """Status of a request.""" From c25b9eb5116bd4465180bde6de07321c5a4dd266 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 7 Feb 2025 23:32:06 -0800 Subject: [PATCH 42/75] bug fix --- tests/v1/worker/test_gpu_input_batch.py | 2 +- vllm/v1/core/kv_cache_manager.py | 3 -- vllm/v1/core/scheduler.py | 51 ++++++++++++++----------- vllm/v1/worker/gpu_input_batch.py | 7 ++-- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index ceb67005055e..786d604eb772 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -192,7 +192,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): # Generate the sampling metadata sampling_metadata = input_batch.make_sampling_metadata( - req_id_output_token_ids, skip_copy=False) + req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False) # Create expected output. expected_sampling_metadata = _construct_expected_sampling_metadata( diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8ff5655e3df8..0945e3243bdc 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -137,9 +137,6 @@ def allocate_slots( not include the tokens that have already been computed. new_computed_blocks: A list of new computed blocks just hitting the prefix caching. - max_speculative_tokens: The maximum number of speculative tokens, - used to calculate the minimum number of full blocks that are - cached. Blocks layout: ----------------------------------------------------------------------- diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 98bada83993a..2d6eb63d40a6 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -446,27 +446,32 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - token_ids = sampled_token_ids[req_index] - # num_computed_tokens_step is the number of tokens computed - # in the current step. - # num_computed_tokens_step = - # num_scheduled_tokens - num_tokens_rejected, - # where num_tokens_rejected = - # len(request.spec_token_ids) + 1 - len(token_ids). - # We use this way of calculating num_computed_tokens_step because of - # chunked prefill. 
In chunked prefill, number of computed tokens - # is not equal to the number of generated/sampled tokens. - # Here, len(request.spec_token_ids) + 1 is the maximum number of - # tokens generated in the current step, - # len(request.spec_token_ids) + 1 - len(token_ids) is the number of - # tokens rejected in the current step. - num_computed_tokens_step = num_scheduled_tokens[req_id] - ( - len(request.spec_token_ids) + 1 - len(token_ids)) - request.num_computed_tokens += num_computed_tokens_step - # When the request's num_computed_tokens catches up its num_tokens, - # the request generates output tokens. Otherwise, we ignore the - # sampler output for the request. - assert request.num_computed_tokens <= request.num_tokens + spec_token_ids = request.spec_token_ids + generated_token_ids = sampled_token_ids[req_index] + if not spec_token_ids: + # When the request's num_computed_tokens catches up + # its num_tokens, the request generates output tokens. + # Otherwise, we ignore the sampler output for the request. + request.num_computed_tokens += num_tokens_scheduled + assert request.num_computed_tokens <= request.num_tokens + else: + # num_computed_tokens_step is the number of tokens computed + # in the current step. + # num_computed_tokens_step = + # num_scheduled_tokens - num_tokens_rejected, + # where num_tokens_rejected = + # len(request.spec_token_ids) + 1 - len(generated_token_ids). + # We use this way of calculating num_computed_tokens_step + # because of chunked prefill. In chunked prefill, number of + # computed tokens is not equal to the number of + # generated/sampled tokens. Here, len(request.spec_token_ids) + # + 1 is the maximum number of tokens generated in the current + # step, len(request.spec_token_ids) + 1 - + # len(generated_token_ids) is the number of tokens rejected + # in the current step. + num_computed_tokens_step = num_scheduled_tokens[req_id] - ( + len(request.spec_token_ids) + 1 - len(generated_token_ids)) + request.num_computed_tokens += num_computed_tokens_step cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) @@ -489,8 +494,8 @@ def update_from_output( request.clear_spec_tokens() new_token_ids = [] - for output_token_id in token_ids: - request.append_output_token_ids(token_ids) + for output_token_id in generated_token_ids: + request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) stopped = self._check_stop(request, output_token_id) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 040f321f85ca..f23964157ebe 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -324,7 +324,7 @@ def condense(self, empty_req_indices: List[int]) -> None: def make_sampling_metadata( self, req_id_output_token_ids: Dict[str, List[int]], - req_id_to_spec_token_ids: Optional[Dict[str, List[int]]] = None, + req_id_to_spec_token_ids: Dict[str, List[int]], skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: @@ -355,7 +355,6 @@ def make_sampling_metadata( output_token_ids: List[List[int]] = [] spec_token_ids: List[List[int]] = [] rejection_sampling = False - req_id_to_spec_token_ids = req_id_to_spec_token_ids or {} for req_id in self.req_ids[:self.num_reqs]: assert req_id is not None # Currently we create a tensor for output_token_ids from scratch @@ -366,9 +365,9 @@ def make_sampling_metadata( # TODO - Replace this with incremental update to output token # statistics. 
output_token_ids.append(req_id_output_token_ids[req_id]) - req_spec_token_ids = req_id_to_spec_token_ids.get(req_id) + req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, []) + spec_token_ids.append(req_spec_token_ids) if req_spec_token_ids: - spec_token_ids.append(req_spec_token_ids) # If any of the requests require speculative decoding, set the # flag to True. rejection_sampling = True From f4ee865e6a5a0b9c3e3644e1c767925648905ef4 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 8 Feb 2025 12:18:17 -0800 Subject: [PATCH 43/75] fix scheduler --- tests/v1/core/test_scheduler.py | 5 ++--- vllm/v1/core/kv_cache_manager.py | 2 +- vllm/v1/core/scheduler.py | 18 +++--------------- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f3c10bb3d52a..bded13596126 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from typing import List, Optional -import torch - from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -42,6 +40,7 @@ def create_scheduler( return Scheduler(scheduler_config, model_config, cache_config, + speculative_config=None, lora_config=None) @@ -202,7 +201,7 @@ def test_schedule_partial_requests(): model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, - sampled_token_ids=torch.tensor([[0]] * len(requests)), + sampled_token_ids=[[0] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c3ffd659cceb..67ad19eada16 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -263,7 +263,7 @@ def free(self, request: Request) -> None: if block.ref_cnt == 0: self.free_block_queue.append(block) - self.cached_block_num.pop(request.request_id) + self.cached_block_num.pop(request.request_id, None) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index b3b89d3cc45b..ae2a1d617b11 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -453,7 +453,6 @@ def update_from_output( scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", ) -> EngineCoreOutputs: - # NOTE(woosuk): This method doesn't consider speculative decoding. sampled_token_ids = model_runner_output.sampled_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict @@ -474,9 +473,8 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - spec_token_ids = request.spec_token_ids generated_token_ids = sampled_token_ids[req_index] - if not spec_token_ids: + if not request.spec_token_ids: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. # Otherwise, we ignore the sampler output for the request. @@ -526,18 +524,9 @@ def update_from_output( stopped = False new_logprobs = None - new_token_ids = None + new_token_ids = [] if request.num_computed_tokens == request.num_tokens: - req_index = model_runner_output.req_id_to_index[req_id] - # NOTE(woosuk): Currently, we assume that each request - # generates at most one token at each step. 
- token_id = sampled_token_ids[req_index] - request.append_output_token_ids(token_id) - num_new_tokens = 1 - # TODO: Update the KV cache manager for prefix caching. - - new_token_ids = [] for output_token_id in generated_token_ids: request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) @@ -551,12 +540,11 @@ def update_from_output( # Extract sample logprobs if needed. if request.sampling_params.logprobs is not None: assert logprobs is not None + req_index = model_runner_output.req_id_to_index[req_id] # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - new_token_ids = request.output_token_ids[-num_new_tokens:] - # Transmit partial if chunked prefill & prompt logprobs is enabled if new_token_ids or prompt_logprobs_tensors is not None: # Add EngineCoreOutput for this Request. From d02844a7866c26255c11f68056e44a5bb1664a84 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 8 Feb 2025 13:08:04 -0800 Subject: [PATCH 44/75] fix scheduler and tests --- tests/v1/core/test_scheduler.py | 244 +++++++++++++++++++++++--------- vllm/v1/core/scheduler.py | 18 ++- 2 files changed, 192 insertions(+), 70 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index bded13596126..0782a9d9dc5c 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -4,7 +4,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.v1.core.scheduler import Scheduler +from vllm.v1.core.scheduler import Scheduler, SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -221,66 +221,184 @@ def test_schedule_partial_requests(): assert requests[2].request_id not in output.num_scheduled_tokens -def test_multiple_stop_tokens(): - """Test with stop when generating multiple tokens""" +def test_stop_via_update_from_output(): + """Test stopping behavior through update_from_output""" scheduler = create_scheduler() - # Nonstop case - request = create_requests(num_requests=1, - max_tokens=100, - stop_token_ids=[42, 43, 44])[0] - scheduler.requests[request.request_id] = request - request.append_output_token_ids([4, 5, 6, 7, 8]) - result = scheduler._maybe_stop_and_crop(request) - assert result is False - - # EOS token is generated in the beginning of the output tokens - request = create_requests(num_requests=1, - max_tokens=100, - stop_token_ids=[42, 43, 44])[0] - scheduler.requests[request.request_id] = request - request.append_output_token_ids([EOS_TOKEN_ID, 5, EOS_TOKEN_ID, 7, 43, 5]) - result = scheduler._maybe_stop_and_crop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_STOPPED - assert len(request.output_token_ids) == 1 - assert list(request.output_token_ids) == [EOS_TOKEN_ID] - - # EOS token is generated in the middle of the output tokens - request = create_requests(num_requests=1, - max_tokens=100, - stop_token_ids=[42, 43, 44])[0] - scheduler.requests[request.request_id] = request - request.append_output_token_ids([1, 2, 3, 4, 5, EOS_TOKEN_ID, 7, 43, 5]) - result = scheduler._maybe_stop_and_crop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_STOPPED - assert len(request.output_token_ids) == 6 - assert list(request.output_token_ids) == [1, 2, 3, 4, 5, EOS_TOKEN_ID] 
- - # Stop token, 43 is one of the stop tokens - request = create_requests(num_requests=1, - max_tokens=100, - stop_token_ids=[42, 43, 44])[0] - scheduler.requests[request.request_id] = request - request.append_output_token_ids([4, 5, 43, 7, 43, 5]) - result = scheduler._maybe_stop_and_crop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_STOPPED - assert request.stop_reason == 43 - # Should be cropped at the first stop token - assert len(request.output_token_ids) == 3 - assert list(request.output_token_ids) == [4, 5, 43] - - # Max tokens, should be cropped when reaching the max tokens - max_tokens = 2 - request = create_requests(num_requests=1, - max_tokens=max_tokens, - stop_token_ids=[42, 43, 44])[0] - scheduler.requests[request.request_id] = request - output_token_ids = [4, 5, 43, 7, 43, 5] - request.append_output_token_ids(output_token_ids) - result = scheduler._maybe_stop_and_crop(request) - assert result is True - assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED - assert len(request.output_token_ids) == max_tokens - assert list(request.output_token_ids) == output_token_ids[:max_tokens] + + # Test case 1: Stop on EOS token + requests = create_requests(num_requests=2, max_tokens=10) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + use_spec_decode=True, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[EOS_TOKEN_ID], + [10, + 11]], # First request hits EOS, second continues + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped, second continues + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID] + assert list(requests[1].output_token_ids) == [10, 11] + + # Test case 2: Stop on custom stop token + scheduler = create_scheduler() + requests = create_requests(num_requests=2, + max_tokens=10, + stop_token_ids=[42, 43]) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + use_spec_decode=True, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 42], + requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + 
sampled_token_ids=[[10, 42, 12], + [13, 14]], # First request hits stop token + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped on custom token + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].stop_reason == 42 + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 42] + assert list(requests[1].output_token_ids) == [13, 14] + + # Test case 3: Stop on max tokens + scheduler = create_scheduler() + requests = create_requests(num_requests=2, max_tokens=2) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + use_spec_decode=True, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[10, 11, 12], + [13]], # First request exceeds max_tokens + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped due to length + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 11 + ] # Truncated to max_tokens + assert list(requests[1].output_token_ids) == [13] + + # Test case 4: Ignore EOS flag + scheduler = create_scheduler() + requests = create_requests(num_requests=1, max_tokens=10) + requests[0].sampling_params.ignore_eos = True + requests[0].num_computed_tokens = requests[0].num_tokens + scheduler.requests[requests[0].request_id] = requests[0] + scheduler.running.append(requests[0]) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3}, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + use_spec_decode=True, + scheduled_spec_decode_tokens={ + requests[0].request_id: [EOS_TOKEN_ID, 10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[]) + + model_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify request continues past EOS + assert len(scheduler.running) == 1 + assert not requests[0].is_finished() + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ae2a1d617b11..f14e90b0015b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -474,7 +474,7 @@ def update_from_output( req_index = 
model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] - if not request.spec_token_ids: + if not scheduler_output.use_spec_decode: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. # Otherwise, we ignore the sampler output for the request. @@ -484,19 +484,23 @@ def update_from_output( # num_computed_tokens_step is the number of tokens computed # in the current step. # num_computed_tokens_step = - # num_scheduled_tokens - num_tokens_rejected, + # num_scheduled_tokens - num_tokens_rejected, # where num_tokens_rejected = - # len(request.spec_token_ids) + 1 - len(generated_token_ids). + # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # We use this way of calculating num_computed_tokens_step # because of chunked prefill. In chunked prefill, number of # computed tokens is not equal to the number of - # generated/sampled tokens. Here, len(request.spec_token_ids) + # generated/sampled tokens. Here, len(scheduled_spec_token_ids) # + 1 is the maximum number of tokens generated in the current - # step, len(request.spec_token_ids) + 1 - + # step, len(scheduled_spec_token_ids) + 1 - # len(generated_token_ids) is the number of tokens rejected # in the current step. + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get( + req_id, [])) num_computed_tokens_step = num_scheduled_tokens[req_id] - ( - len(request.spec_token_ids) + 1 - len(generated_token_ids)) + len(scheduled_spec_token_ids) + 1 - + len(generated_token_ids)) request.num_computed_tokens += num_computed_tokens_step cached_encoder_input_ids = ( @@ -526,7 +530,7 @@ def update_from_output( new_logprobs = None new_token_ids = [] - if request.num_computed_tokens == request.num_tokens: + if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) From 50ab1622670ebdaa2b414fd5c599058b19251bfd Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Sat, 8 Feb 2025 13:31:24 -0800 Subject: [PATCH 45/75] Simplify request Co-authored-by: Nick Hill --- vllm/v1/request.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index dde1fc39c665..20b19e297879 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -98,8 +98,9 @@ def append_spec_token_ids( token_ids: Union[int, List[int]], ) -> None: if isinstance(token_ids, int): - token_ids = [token_ids] - self._spec_token_ids.extend(token_ids) + self._spec_token_ids.append(token_ids) + else: + self._spec_token_ids.extend(token_ids) def clear_spec_tokens(self) -> None: self._spec_token_ids = [] From f3b08f407278641c5d50479d0af2366e9bde0a83 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 8 Feb 2025 13:37:39 -0800 Subject: [PATCH 46/75] rejection sampling tests update --- tests/v1/sample/test_rejection_sampler.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 4107b4c0dbe3..a7127404e72d 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -165,18 +165,3 @@ def test_logits_shape_handling(sampler): expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device) assert torch.equal(output.sampled_token_ids, expected) assert logits.shape[-1] == vocab_size - - -def test_none_outputs(sampler): - """Test that other output 
fields are None as expected""" - spec_tokens = [[1]] - output_tokens = [1, 2] - - metadata = create_sampling_metadata(spec_tokens) - logits = create_logits_tensor(output_tokens) - - output = sampler.sample(logits, metadata) - assert output.logprob_token_ids is None - assert output.logprobs is None - assert output.prompt_logprob_token_ids is None - assert output.prompt_logprobs is None From 840413bf63eb8c76741200e8d8fdd65c7c5d863b Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 8 Feb 2025 14:23:02 -0800 Subject: [PATCH 47/75] optimize rejection sampler --- vllm/v1/sample/rejection_sampler.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 4630d6ca0e1a..71ed7abed469 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -97,17 +97,16 @@ def greedy_sample_ref( for spec_tokens in spec_token_ids: num_spec_tokens = len(spec_tokens) max_spec_len = max(max_spec_len, num_spec_tokens) - output_tokens = output_token_ids_cpu[ - output_token_start_idx:output_token_start_idx + 1 + - num_spec_tokens] i = 0 - while i < len(spec_tokens): - if spec_tokens[i] != output_tokens[i]: + while i < num_spec_tokens: + if spec_tokens[i] != output_token_ids_cpu[ + output_token_start_idx + i]: break i += 1 # +1 to include the bonus token. i += 1 - output_tokens = output_tokens[:i] + output_tokens = output_token_ids_cpu[ + output_token_start_idx:output_token_start_idx + i] sampled_token_ids.append(output_tokens) output_token_start_idx += num_spec_tokens + 1 From 03f6bee08f69292eb14b2ad717b2a1601fb4ccbc Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 8 Feb 2025 14:25:21 -0800 Subject: [PATCH 48/75] static --- vllm/v1/sample/rejection_sampler.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 71ed7abed469..6a816bc94c3c 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -25,12 +25,15 @@ def sample(self, logits: torch.Tensor, "Only greedy sampling is supported for now.") if is_flashinfer_available: - return self.flashinfer_sample(logits, sampling_metadata) + return RejectionSampler.flashinfer_sample(logits, + sampling_metadata) else: - return self.greedy_sample_ref(logits, sampling_metadata) + return RejectionSampler.greedy_sample_ref(logits, + sampling_metadata) + @staticmethod def flashinfer_sample( - self, logits: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: spec_token_ids = sampling_metadata.spec_token_ids spec_lengths = torch.tensor([len(s) for s in spec_token_ids], @@ -60,12 +63,10 @@ def flashinfer_sample( vocab_size = logits.size(-1) draft_token_ids = draft_token_ids.to(logits.device) - draft_probs = self._create_greedy_token_probs(draft_token_ids, - vocab_size, - logits.device) - target_probs = self._create_greedy_token_probs(target_token_ids, - vocab_size, - logits.device) + draft_probs = RejectionSampler._create_greedy_token_probs( + draft_token_ids, vocab_size, logits.device) + target_probs = RejectionSampler._create_greedy_token_probs( + target_token_ids, vocab_size, logits.device) uniform_samples = torch.zeros(batch_size, max_spec_len + 1, device=logits.device) @@ -79,8 +80,9 @@ def flashinfer_sample( return SamplerOutput(sampled_token_ids=sampled_token_ids, logprobs_tensors=None) + @staticmethod def greedy_sample_ref( - self, 
logits: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: # num_reqs x [num_speculated_tokens] spec_token_ids = sampling_metadata.spec_token_ids @@ -123,8 +125,8 @@ def greedy_sample_ref( return SamplerOutput(sampled_token_ids=sampled_token_ids, logprobs_tensors=None) - def _create_greedy_token_probs(self, token_ids: torch.Tensor, - vocab_size: int, + @staticmethod + def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int, out_device: torch.device) -> torch.Tensor: batch_size, num_tokens = token_ids.shape From 5fb9ac1a408cb11ac87abaf5669eb09b0b9b0e89 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 8 Feb 2025 14:27:52 -0800 Subject: [PATCH 49/75] format --- vllm/v1/engine/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 393fc1fe8579..bc4b25614a67 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -153,7 +153,7 @@ def propose_tokens(self): if req.num_computed_tokens < req.num_tokens - 1: continue # Ignore requests that already have spec tokens. - if len(req.spec_token_ids) > 0: + if req.spec_token_ids: continue spec_tokens = self.proposer.propose( req.all_token_ids, From 7c0497e67bba7bd3d2b1286cc191e24693541a92 Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Sun, 9 Feb 2025 16:53:34 -0800 Subject: [PATCH 50/75] Update vllm/v1/core/scheduler.py --- vllm/v1/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f14e90b0015b..de739fbd9afc 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -118,7 +118,7 @@ def schedule(self) -> "SchedulerOutput": scheduled_encoder_inputs: Dict[str, List[int]] = {} encoder_budget = self.max_num_encoder_input_tokens - # Spec Decode-related. + # Spec Decode-related. spec_decode: if any request in the scheduled batch uses speculative decoding. 
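Stepping back to the greedy_sample_ref path in the rejection sampler above: draft tokens are accepted as long as they agree with the target model's greedy choice at the same position, and one extra token (the bonus token, or the correction at the first mismatch) is always kept. A minimal sketch of that acceptance rule with made-up inputs (not the vLLM implementation):

    from typing import List

    def greedy_accept(spec_tokens: List[int], target_argmax: List[int]) -> List[int]:
        """target_argmax[i] is the target model's greedy token at verification
        position i, so it holds len(spec_tokens) + 1 entries (the last one is
        the bonus token). Accept drafts while they match, then keep one more."""
        i = 0
        while i < len(spec_tokens) and spec_tokens[i] == target_argmax[i]:
            i += 1
        # The token at the first mismatch (or the bonus token when every draft
        # matched) is always emitted.
        return target_argmax[:i + 1]

    print(greedy_accept([7, 8, 9], [7, 8, 9, 4]))  # all accepted -> [7, 8, 9, 4]
    print(greedy_accept([7, 5, 9], [7, 8, 9, 4]))  # reject at pos 1 -> [7, 8]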
spec_decode = False scheduled_spec_decode_tokens: Dict[str, List[int]] = {} From e0bd8cc1ab715286406caf9b1218d94a9691c2e5 Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Sun, 9 Feb 2025 16:54:13 -0800 Subject: [PATCH 51/75] Update vllm/v1/worker/gpu_model_runner.py Co-authored-by: Nick Hill --- vllm/v1/worker/gpu_model_runner.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a8540020180..284323c2b0a9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -894,10 +894,7 @@ def execute_model( # Update with the actual token ids sampled_token_ids = sampled_token_ids.tolist() for i, req_state, seq_len, gen_len in request_seq_lens: - token_ids = [ - x for x in sampled_token_ids[i] if x != INVALID_TOKEN_ID - ] - sampled_token_ids[i] = token_ids + del sampled_token_ids[i][gen_len:] for j, token_id in enumerate(token_ids): self.input_batch.token_ids_cpu[i, seq_len - gen_len + j + 1] = token_id From 95d34f09283f141c2aeca6a23be372dab8ac0253 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 9 Feb 2025 23:31:07 -0800 Subject: [PATCH 52/75] fix comments --- tests/v1/e2e/test_ngram_specdecode.py | 2 +- vllm/v1/core/kv_cache_manager.py | 3 +-- vllm/v1/core/scheduler.py | 8 ++++---- vllm/v1/engine/core.py | 3 ++- vllm/v1/outputs.py | 2 +- vllm/v1/request.py | 14 +++++--------- 6 files changed, 14 insertions(+), 18 deletions(-) diff --git a/tests/v1/e2e/test_ngram_specdecode.py b/tests/v1/e2e/test_ngram_specdecode.py index a4750eb62075..e821b1893d22 100644 --- a/tests/v1/e2e/test_ngram_specdecode.py +++ b/tests/v1/e2e/test_ngram_specdecode.py @@ -7,7 +7,7 @@ @pytest.fixture def test_prompts(): return [ - "Can you repeat the sentence ten times, this is a sentence?", + "Can you repeat the sentence ten times, this is a sentence.", "This is a basic spec decode test", ] diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 67ad19eada16..b649d607734a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -223,8 +223,7 @@ def allocate_slots( if not self.enable_caching: return new_blocks - num_cached_blocks = self.cached_block_num[ - request.request_id] // self.block_size + num_cached_blocks = self.cached_block_num[request.request_id] num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( request.spec_token_ids)) // self.block_size new_full_blocks = req_blocks[ diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f14e90b0015b..dd327d57c17b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -528,14 +528,14 @@ def update_from_output( stopped = False new_logprobs = None - new_token_ids = [] + new_token_ids: List[int] = [] if request.num_computed_tokens >= request.num_tokens: for output_token_id in generated_token_ids: request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) - stopped = self._check_stop(request, output_token_id) + stopped = self._check_stop(request) # This must be called before we make the EngineCoreOutput. if stopped: self._free_request(request) @@ -544,7 +544,6 @@ def update_from_output( # Extract sample logprobs if needed. if request.sampling_params.logprobs is not None: assert logprobs is not None - req_index = model_runner_output.req_id_to_index[req_id] # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. 
new_logprobs = logprobs.slice(req_index, req_index + 1) @@ -570,13 +569,14 @@ def update_from_output( scheduler_stats=self.make_stats(), ) - def _check_stop(self, request: Request, last_token_id: int) -> bool: + def _check_stop(self, request: Request) -> bool: if (request.num_tokens >= self.max_model_len or request.num_output_tokens >= request.max_tokens): request.status = RequestStatus.FINISHED_LENGTH_CAPPED return True sampling_params = request.sampling_params + last_token_id = request.output_token_ids[-1] if (not sampling_params.ignore_eos and last_token_id == request.eos_token_id): request.status = RequestStatus.FINISHED_STOPPED diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index bc4b25614a67..7ef40b0a3e99 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -72,6 +72,7 @@ def __init__( assert self.scheduler.speculative_config.ngram_prompt_lookup_min \ , "Only ngram spec decode is supported in V1." self.proposer = NgramProposer() + self.use_spec_decode = True def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: @@ -131,7 +132,7 @@ def step(self) -> EngineCoreOutputs: return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - if self.scheduler.speculative_config: + if self.use_spec_decode: self.propose_tokens() scheduler_output = self.scheduler.schedule() diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 4c503748e514..fb6c4051e9a6 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -46,7 +46,7 @@ class SamplerOutput: # [num_reqs, max_num_generated_tokens] # Different requests can have different number of generated tokens. # All requests are padded to max_num_generated_tokens. - # INVALID_TOKEN_ID is used for padding. + # INVALID_TOKEN_ID (-1 by default) is used for padding. 
sampled_token_ids: torch.Tensor logprobs_tensors: Optional[LogprobsTensors] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 20b19e297879..5752e81fb6cd 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -50,7 +50,7 @@ def __init__( self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: List[int] = [] self._all_token_ids: List[int] = self.prompt_token_ids.copy() - self._spec_token_ids: List[int] = [] + self.spec_token_ids: List[int] = [] self.num_computed_tokens = 0 # Multi-modal related @@ -98,16 +98,12 @@ def append_spec_token_ids( token_ids: Union[int, List[int]], ) -> None: if isinstance(token_ids, int): - self._spec_token_ids.append(token_ids) + self.spec_token_ids.append(token_ids) else: - self._spec_token_ids.extend(token_ids) + self.spec_token_ids.extend(token_ids) def clear_spec_tokens(self) -> None: - self._spec_token_ids = [] - - @property - def spec_token_ids(self) -> List[int]: - return self._spec_token_ids + self.spec_token_ids = [] @property def num_tokens(self) -> int: @@ -115,7 +111,7 @@ def num_tokens(self) -> int: @property def num_tokens_with_spec(self) -> int: - return len(self._all_token_ids) + len(self._spec_token_ids) + return len(self._all_token_ids) + len(self.spec_token_ids) @property def num_output_tokens(self) -> int: From 4cc5f8db861eb18723d3735401f9d50529a1f01d Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 9 Feb 2025 23:45:10 -0800 Subject: [PATCH 53/75] minor --- vllm/v1/worker/gpu_model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 284323c2b0a9..3290b3d378c3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -363,8 +363,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.condense(removed_req_indices) return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 - def _prepare_inputs(self, scheduler_output: "SchedulerOutput") \ - -> Tuple[FlashAttentionMetadata, torch.Tensor]: + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> Tuple[FlashAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -895,7 +896,7 @@ def execute_model( sampled_token_ids = sampled_token_ids.tolist() for i, req_state, seq_len, gen_len in request_seq_lens: del sampled_token_ids[i][gen_len:] - for j, token_id in enumerate(token_ids): + for j, token_id in enumerate(sampled_token_ids[i]): self.input_batch.token_ids_cpu[i, seq_len - gen_len + j + 1] = token_id req_state.output_token_ids[-gen_len + j] = token_id From 4086a77991921e4699b3cef41efcf2e52c2b0f52 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 10 Feb 2025 15:38:28 -0800 Subject: [PATCH 54/75] input prepare --- vllm/v1/worker/gpu_model_runner.py | 57 +++++++++++++++++------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 64308b3f2c3d..0e6ad004f8de 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -179,6 +179,7 @@ def __init__( self.max_model_len, self.max_num_tokens), dtype=np.int32) + self.arange_cpu = torch.from_numpy(self.arange_np) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. 
Thus, we should # not make any assumptions about the values in these tensors. @@ -380,12 +381,19 @@ def _prepare_inputs( # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens_list: List[int] = [] max_num_scheduled_tokens = 0 + all_spec_token_ids: List[int] = [] + num_spec_tokens: List[int] = [] for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens_list.append(num_tokens) max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, []) + all_spec_token_ids.extend(spec_token_ids) + num_spec_tokens.append(len(spec_token_ids)) + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, dtype=np.int32) assert max_num_scheduled_tokens > 0 @@ -425,28 +433,29 @@ def _prepare_inputs( token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - # Add spec decode tokens to input_batch.token_ids_cpu. - # Get spec decode logits indices. - spec_query_end_loc = 0 - spec_decode_logits_indices: List[int] = [] - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None - req_num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - num_compute_tokens = self.input_batch.num_computed_tokens_cpu[i] - spec_query_end_loc += req_num_scheduled_tokens - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - for j, spec_token_id in enumerate(spec_token_ids): - # +1 here because the input for verification is - # [last_output_token_id] + spec_token_ids - self.input_batch.token_ids_cpu[i, num_compute_tokens + 1 + - j] = spec_token_id - # -1 here because the input for verification is - # [last_output_token_id] + spec_token_ids - spec_decode_logits_indices.extend( - range(spec_query_end_loc - len(spec_token_ids) - 1, - spec_query_end_loc)) + if scheduler_output.use_spec_decode and all_spec_token_ids: + # Currently, we assume all speculated tokens are verified. + # We don't support partial verification. + # 1. Get spec decode logits indices. + spec_decode_logits_indices = self.arange_cpu[: + total_num_scheduled_tokens] + # 2. Write spec token ids to input_ids_cpu. + all_spec_token_ids = torch.tensor(all_spec_token_ids, + dtype=torch.int32) + # Step 1. Calculate the spec token indices within input_ids_cpu. + # E.g., num_spec_tokens: [3, 0, 2, 0, 1] + # num_scheduled_tokens: [4, 1, 3, 1, 2] + # cu_num_tokens - num_scheduled_tokens: [0, 4, 5, 8, 9] + # cumsums_spec_offsets [0, 0, 0, 5, 5, 9] + cumsums_spec_offsets = np.repeat( + cu_num_tokens - num_scheduled_tokens, num_spec_tokens) + + # Step 2. Write spec token ids to input_ids_cpu. + torch.index_select( + all_spec_token_ids, + 0, + torch.from_numpy(cumsums_spec_offsets), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large @@ -542,8 +551,8 @@ def _prepare_inputs( ) if scheduler_output.use_spec_decode: - logits_indices = torch.tensor(spec_decode_logits_indices, - device=self.device) + logits_indices = spec_decode_logits_indices.to(self.device, + non_blocking=True) else: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. 
While we should not sample any token From 1e218af556b2e569019d72db3a6e1ae2c06702d4 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 10 Feb 2025 21:19:52 -0800 Subject: [PATCH 55/75] fix input prepare --- vllm/v1/engine/core.py | 1 + vllm/v1/sample/rejection_sampler.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 47 ++++++++++++++++++----------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index df1b0b78b62a..20eb2cba6afa 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -66,6 +66,7 @@ def __init__( vllm_config.model_config) # TODO: find a better way to check if we are using ngram. + self.use_spec_decode = False if self.scheduler.speculative_config: assert self.scheduler.speculative_config.ngram_prompt_lookup_min \ , "Only ngram spec decode is supported in V1." diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 6a816bc94c3c..2d543c05db4e 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -22,7 +22,7 @@ def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: if not sampling_metadata.all_greedy: raise NotImplementedError( - "Only greedy sampling is supported for now.") + "Only greedy sampling is supported by rejection sampler.") if is_flashinfer_available: return RejectionSampler.flashinfer_sample(logits, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0e6ad004f8de..4e38dc54e932 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -382,7 +382,7 @@ def _prepare_inputs( num_scheduled_tokens_list: List[int] = [] max_num_scheduled_tokens = 0 all_spec_token_ids: List[int] = [] - num_spec_tokens: List[int] = [] + num_spec_tokens_list: List[int] = [] for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] @@ -392,7 +392,7 @@ def _prepare_inputs( spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( req_id, []) all_spec_token_ids.extend(spec_token_ids) - num_spec_tokens.append(len(spec_token_ids)) + num_spec_tokens_list.append(len(spec_token_ids)) num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, dtype=np.int32) @@ -433,29 +433,42 @@ def _prepare_inputs( token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - if scheduler_output.use_spec_decode and all_spec_token_ids: + if all_spec_token_ids: # Currently, we assume all speculated tokens are verified. # We don't support partial verification. # 1. Get spec decode logits indices. spec_decode_logits_indices = self.arange_cpu[: total_num_scheduled_tokens] - # 2. Write spec token ids to input_ids_cpu. - all_spec_token_ids = torch.tensor(all_spec_token_ids, - dtype=torch.int32) # Step 1. Calculate the spec token indices within input_ids_cpu. - # E.g., num_spec_tokens: [3, 0, 2, 0, 1] - # num_scheduled_tokens: [4, 1, 3, 1, 2] - # cu_num_tokens - num_scheduled_tokens: [0, 4, 5, 8, 9] - # cumsums_spec_offsets [0, 0, 0, 5, 5, 9] - cumsums_spec_offsets = np.repeat( - cu_num_tokens - num_scheduled_tokens, num_spec_tokens) + # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1] + # spec_req_indices: [0, 0, 0, 2, 2, 4] + spec_req_indices = np.repeat(self.arange_np[:num_reqs], + num_spec_tokens_list) + # spec_offsets: offsets within each spec token list. 
+ # E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here + spec_offsets = np.concatenate( + [self.arange_np[1:val + 1] for val in num_spec_tokens_list]) + # spec_seq_offsets: offsets within each sequence. + # E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2] + # after repeating: [1, 1, 1, 3, 3, 2] + # spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1] + # = [2, 3, 4, 4, 5, 3] + spec_seq_offsets = np.repeat( + self.input_batch.num_computed_tokens_cpu[:num_reqs], + num_spec_tokens_list) + spec_offsets + # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3] + cumsums_spec_offsets = ( + spec_seq_offsets + + spec_req_indices * self.input_batch.token_ids_cpu.shape[1]) + cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to( + torch.int64) + all_spec_token_ids = torch.tensor(all_spec_token_ids, + device=self.input_ids_cpu.device, + dtype=self.input_ids_cpu.dtype) # Step 2. Write spec token ids to input_ids_cpu. - torch.index_select( - all_spec_token_ids, - 0, - torch.from_numpy(cumsums_spec_offsets), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + self.input_batch.token_ids_cpu_tensor.flatten().scatter_( + 0, cumsums_spec_offsets, all_spec_token_ids) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large From 633567a4f949e57ed7c8013d5f469cc48f5abb9e Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 10 Feb 2025 21:30:03 -0800 Subject: [PATCH 56/75] simplify scheduleroutput --- tests/v1/core/test_scheduler.py | 4 ---- tests/v1/e2e/test_ngram_specdecode.py | 2 +- vllm/v1/core/kv_cache_manager.py | 3 +-- vllm/v1/core/scheduler.py | 15 +++++++-------- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 0782a9d9dc5c..ccf34edd3384 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -240,7 +240,6 @@ def test_stop_via_update_from_output(): }, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, - use_spec_decode=True, scheduled_spec_decode_tokens={ requests[0].request_id: [], requests[1].request_id: [10] @@ -289,7 +288,6 @@ def test_stop_via_update_from_output(): }, total_num_scheduled_tokens=5, scheduled_encoder_inputs={}, - use_spec_decode=True, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 42], requests[1].request_id: [13] @@ -336,7 +334,6 @@ def test_stop_via_update_from_output(): }, total_num_scheduled_tokens=4, scheduled_encoder_inputs={}, - use_spec_decode=True, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 11], requests[1].request_id: [] @@ -381,7 +378,6 @@ def test_stop_via_update_from_output(): num_scheduled_tokens={requests[0].request_id: 3}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, - use_spec_decode=True, scheduled_spec_decode_tokens={ requests[0].request_id: [EOS_TOKEN_ID, 10] }, diff --git a/tests/v1/e2e/test_ngram_specdecode.py b/tests/v1/e2e/test_ngram_specdecode.py index e821b1893d22..936d9336640d 100644 --- a/tests/v1/e2e/test_ngram_specdecode.py +++ b/tests/v1/e2e/test_ngram_specdecode.py @@ -15,7 +15,7 @@ def test_prompts(): @pytest.fixture def sampling_config(): # Only support greedy for now - return SamplingParams(temperature=0, max_tokens=100, ignore_eos=False) + return SamplingParams(temperature=0, max_tokens=50, ignore_eos=False) @pytest.fixture diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 
3acfc577541b..b7f04b6ecbc5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -79,8 +79,7 @@ def __init__( str, List[BlockHashType]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} - # This is used to track the number of cached blocks for each request, - # currently only used for speculative decoding. + # This is used to track the number of cached blocks for each request. self.cached_block_num: Dict[str, int] = defaultdict(int) @property diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 43b12fcb6d54..b52e64e9d194 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -118,8 +118,8 @@ def schedule(self) -> "SchedulerOutput": scheduled_encoder_inputs: Dict[str, List[int]] = {} encoder_budget = self.max_num_encoder_input_tokens - # Spec Decode-related. spec_decode: if any request in the scheduled batch uses speculative decoding. - spec_decode = False + # Spec Decode-related. spec_decode: if any request in + # the scheduled batch uses speculative decoding. scheduled_spec_decode_tokens: Dict[str, List[int]] = {} # First, schedule the RUNNING requests. @@ -191,9 +191,8 @@ def schedule(self) -> "SchedulerOutput": # Speculative decode related. if request.spec_token_ids: - spec_decode = True - scheduled_spec_decode_tokens[ - request.request_id] = request.spec_token_ids + scheduled_spec_decode_tokens[ + request.request_id] = request.spec_token_ids # Record the LoRAs in scheduled_running_reqs requested_loras: Set[int] = set() @@ -342,7 +341,6 @@ def schedule(self) -> "SchedulerOutput": num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, - use_spec_decode=spec_decode, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, num_common_prefix_blocks=num_common_prefix_blocks, # finished_req_ids is an existing state in the scheduler, @@ -464,6 +462,8 @@ def update_from_output( # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid # expensive operations inside the loop. + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 for request in self.running: req_id = request.request_id num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) @@ -474,7 +474,7 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] - if not scheduler_output.use_spec_decode: + if not use_spec_decode: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. # Otherwise, we ignore the sampler output for the request. @@ -717,7 +717,6 @@ class SchedulerOutput: num_scheduled_tokens: Dict[str, int] total_num_scheduled_tokens: int scheduled_encoder_inputs: Dict[str, List[int]] - use_spec_decode: bool scheduled_spec_decode_tokens: Dict[str, List[int]] num_common_prefix_blocks: int diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4e38dc54e932..d8a2675b293d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -433,7 +433,8 @@ def _prepare_inputs( token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - if all_spec_token_ids: + use_spec_decode = len(all_spec_token_ids) > 0 + if use_spec_decode: # Currently, we assume all speculated tokens are verified. 
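# That is, every drafted token is fed through the target model in this
# forward pass (as [last_output_token_id] + spec_token_ids), and
# acceptance is decided afterwards by the rejection sampler.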
# We don't support partial verification. # 1. Get spec decode logits indices. @@ -563,7 +564,7 @@ def _prepare_inputs( suffix_kv_lens=suffix_kv_lens, ) - if scheduler_output.use_spec_decode: + if use_spec_decode: logits_indices = spec_decode_logits_indices.to(self.device, non_blocking=True) else: From 888f183e5449cfe09909aac733db0a00b7299f7a Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 10 Feb 2025 21:40:34 -0800 Subject: [PATCH 57/75] change test case to make output more deterministic --- tests/v1/e2e/test_ngram_specdecode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_ngram_specdecode.py b/tests/v1/e2e/test_ngram_specdecode.py index 936d9336640d..a45ed173752f 100644 --- a/tests/v1/e2e/test_ngram_specdecode.py +++ b/tests/v1/e2e/test_ngram_specdecode.py @@ -8,14 +8,14 @@ def test_prompts(): return [ "Can you repeat the sentence ten times, this is a sentence.", - "This is a basic spec decode test", + "Can you repeat the sentence ten times, the future of AI is right.", ] @pytest.fixture def sampling_config(): # Only support greedy for now - return SamplingParams(temperature=0, max_tokens=50, ignore_eos=False) + return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False) @pytest.fixture From 353c3726eff60e96f9d2d4686131b06388773f04 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 10 Feb 2025 23:28:24 -0800 Subject: [PATCH 58/75] update cpu gpu sync --- vllm/v1/sample/sampler.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 33 +++++++++++++++--------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index b961cf01949a..2d6cbf5ad1e1 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -64,6 +64,9 @@ def forward( # These are GPU tensors. sampler_output = SamplerOutput( + # The sampled tokens are expanded to 2D tensor with shape + # [num_requests, 1], where each row represents one generated + # token per request. sampled_token_ids=sampled.unsqueeze(-1), logprobs_tensors=logprobs_tensors, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d8a2675b293d..f94ebf06cc6a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -905,7 +905,7 @@ def execute_model( # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs - request_seq_lens: List[Tuple[int, CachedRequestState, int, int]] = [] + request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] @@ -914,11 +914,7 @@ def execute_model( if seq_len >= req_state.num_tokens: # We don't rewind the generator state for requests now # because spec decode only supports greedy decoding for now. - gen_len = (sampled_token_ids[i] - != INVALID_TOKEN_ID).sum().item() - self.input_batch.num_tokens[i] += gen_len - req_state.output_token_ids.extend([0] * gen_len) - request_seq_lens.append((i, req_state, seq_len, gen_len)) + request_seq_lens.append((i, req_state, seq_len)) else: # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. 
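# The hunks below replace the runner's per-token Python bookkeeping with
# a padded-tensor update. A minimal, self-contained sketch of that idea
# (illustrative values, not the exact variables in this diff): the
# sampler returns a [num_reqs, max_gen_len] tensor padded with
# INVALID_TOKEN_ID, and the accepted tokens are recovered per request
# with a validity mask.
import torch

INVALID_TOKEN_ID = -1
sampled = torch.tensor([[7, 9, INVALID_TOKEN_ID],
                        [3, INVALID_TOKEN_ID, INVALID_TOKEN_ID]])
valid_mask = sampled != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()                  # [2, 1]
accepted = [seq.tolist() for seq in sampled[valid_mask].split(gen_lens)]
assert accepted == [[7, 9], [3]]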
@@ -933,9 +929,9 @@ def execute_model( self.input_batch.req_ids[:num_reqs]), "req_ids contains None" req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + logprobs_tensors = sampler_output.logprobs_tensors # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors logprobs_lists = logprobs_tensors.tolists() \ if logprobs_tensors is not None else None @@ -945,19 +941,24 @@ def execute_model( scheduler_output, ) - # Update with the actual token ids - sampled_token_ids = sampled_token_ids.tolist() - for i, req_state, seq_len, gen_len in request_seq_lens: - del sampled_token_ids[i][gen_len:] - for j, token_id in enumerate(sampled_token_ids[i]): - self.input_batch.token_ids_cpu[i, seq_len - gen_len + j + - 1] = token_id - req_state.output_token_ids[-gen_len + j] = token_id + # Update batch with the valid generated tokens. + valid_mask = sampled_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + valid_sampled_token_ids = [ + seq.tolist() + for seq in sampled_token_ids[valid_mask].split(gen_lens) + ] + self.input_batch.num_tokens[:num_reqs] += gen_lens + for i, req_state, seq_len in request_seq_lens: + target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) + self.input_batch.token_ids_cpu[ + i, target_slice] = valid_sampled_token_ids[i] + req_state.output_token_ids.extend(valid_sampled_token_ids[i]) model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=sampled_token_ids, + sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, ) From 9416792b589d91066692a77ca31819b4bade9f39 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 11 Feb 2025 01:00:19 -0800 Subject: [PATCH 59/75] vectorize rejection sampler --- vllm/v1/sample/rejection_sampler.py | 81 +++++++++++++++-------------- vllm/v1/worker/gpu_model_runner.py | 7 ++- 2 files changed, 49 insertions(+), 39 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 2d543c05db4e..9e610f552352 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence from vllm.logger import init_logger from vllm.v1.outputs import SamplerOutput @@ -84,45 +85,49 @@ def flashinfer_sample( def greedy_sample_ref( logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: - # num_reqs x [num_speculated_tokens] - spec_token_ids = sampling_metadata.spec_token_ids - # only argmax is supported for now - output_token_ids_cpu = logits.argmax(dim=-1).view(-1).tolist() - - sampled_token_ids = [] - # Stop at the first mismatch place. - # spec_tokens: [1, 2, 3] - # output_tokens: [1, 2, 4, 5] - # sampled_tokens: [1, 2, 4] - output_token_start_idx = 0 - max_spec_len = -1 - for spec_tokens in spec_token_ids: - num_spec_tokens = len(spec_tokens) - max_spec_len = max(max_spec_len, num_spec_tokens) - i = 0 - while i < num_spec_tokens: - if spec_tokens[i] != output_token_ids_cpu[ - output_token_start_idx + i]: - break - i += 1 - # +1 to include the bonus token. 
- i += 1 - output_tokens = output_token_ids_cpu[ - output_token_start_idx:output_token_start_idx + i] - sampled_token_ids.append(output_tokens) - output_token_start_idx += num_spec_tokens + 1 - - sampled_token_ids = [ - x + [INVALID_TOKEN_ID] * (max_spec_len + 1 - len(x)) - for x in sampled_token_ids + spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] + # Add 1 to include the 'bonus' token + sample_lens = [x + 1 for x in spec_lens] + + output_token_ids = logits.argmax(dim=-1).view(-1) + output_token_ids = output_token_ids.split(sample_lens) + output_token_ids = pad_sequence(output_token_ids, + batch_first=True, + padding_value=INVALID_TOKEN_ID) + + # Convert spec token IDs to a tensor, split by sample_lens, then pad. + spec_token_ids = [ + torch.tensor(x, + dtype=output_token_ids.dtype, + device=output_token_ids.device) + for x in sampling_metadata.spec_token_ids ] - sampled_token_ids = torch.tensor(sampled_token_ids, - device=logits.device, - dtype=torch.int) - - assert output_token_start_idx == len(output_token_ids_cpu) - - return SamplerOutput(sampled_token_ids=sampled_token_ids, + spec_token_ids = pad_sequence(spec_token_ids, + batch_first=True, + padding_value=INVALID_TOKEN_ID) + + # Produce a mask that remains 1 (True) until the first + # mismatch (cumprod turns 0 after a mismatch). + accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod( + dim=1) + # Identify valid positions (non-padding) + valid_mask = output_token_ids != INVALID_TOKEN_ID + + # Generate mask with bonus token. + generate_mask = torch.cat([ + accept_mask, + torch.zeros(accept_mask.size(0), 1, device=accept_mask.device) + ], + dim=1).to(torch.bool) & valid_mask + zeros_mask = (generate_mask == 0) + first_zero_idx = zeros_mask.float().argmax(dim=1) + # Figure out which rows actually contain at least one zero. + rows_with_zero = zeros_mask.any(dim=1) + # Use indexing to set the first zero in each of those rows to 1. + generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1 + + output_token_ids[~generate_mask] = INVALID_TOKEN_ID + return SamplerOutput(sampled_token_ids=output_token_ids, logprobs_tensors=None) @staticmethod diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f94ebf06cc6a..c3cc2c7f04b9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -440,7 +440,12 @@ def _prepare_inputs( # 1. Get spec decode logits indices. spec_decode_logits_indices = self.arange_cpu[: total_num_scheduled_tokens] - # Step 1. Calculate the spec token indices within input_ids_cpu. + # 2. Write spec_token_ids to input batch. + # Step 1. Get req indices that perform spec decode and repeat + # the req indices by the number of spec tokens. Note + # for requests that don't perform spec decode, the + # number of spec tokens is 0 and the req index is + # repeated 0 times. 
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1] # spec_req_indices: [0, 0, 0, 2, 2, 4] spec_req_indices = np.repeat(self.arange_np[:num_reqs], From 54c5fa52e12dd288fed47fd6b9113ce4ad8a1168 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 11 Feb 2025 14:29:09 -0800 Subject: [PATCH 60/75] fix comments --- .../{test_ngram_specdecode.py => test_ngram_spec_decode.py} | 0 vllm/v1/core/scheduler.py | 6 ++---- vllm/v1/sample/rejection_sampler.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) rename tests/v1/e2e/{test_ngram_specdecode.py => test_ngram_spec_decode.py} (100%) diff --git a/tests/v1/e2e/test_ngram_specdecode.py b/tests/v1/e2e/test_ngram_spec_decode.py similarity index 100% rename from tests/v1/e2e/test_ngram_specdecode.py rename to tests/v1/e2e/test_ngram_spec_decode.py diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index e36316f46dec..f93fafffbeb4 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -456,8 +456,6 @@ def update_from_output( # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid # expensive operations inside the loop. - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 for request in self.running: req_id = request.request_id num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) @@ -468,7 +466,7 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] - if not use_spec_decode: + if req_id not in scheduler_output.scheduled_spec_decode_tokens: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. # Otherwise, we ignore the sampler output for the request. @@ -529,8 +527,8 @@ def update_from_output( request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) - stopped = self._check_stop(request) # This must be called before we make the EngineCoreOutput. 
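# With speculative decoding a single step can emit several tokens, so
# the stop condition is re-checked after every appended token and the
# loop breaks early; tokens past an EOS/stop match are never surfaced.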
+ stopped = self._check_stop(request) if stopped: self._free_request(request) break diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 9e610f552352..d9e8ff86fa87 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -19,7 +19,7 @@ class RejectionSampler(nn.Module): - def sample(self, logits: torch.Tensor, + def forward(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: if not sampling_metadata.all_greedy: raise NotImplementedError( From 4ea2fdaff4055bc2587a69cf7907951d8d330991 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 11 Feb 2025 14:37:13 -0800 Subject: [PATCH 61/75] minor --- vllm/v1/sample/rejection_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index d9e8ff86fa87..08b04c8ff202 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -20,7 +20,7 @@ class RejectionSampler(nn.Module): def forward(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> SamplerOutput: + sampling_metadata: SamplingMetadata) -> SamplerOutput: if not sampling_metadata.all_greedy: raise NotImplementedError( "Only greedy sampling is supported by rejection sampler.") From 0d6d713e2010181ad4e2faa0fcfd6d6d77adedf1 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 11 Feb 2025 15:21:22 -0800 Subject: [PATCH 62/75] minor --- vllm/v1/sample/rejection_sampler.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 08b04c8ff202..8aa9cb291a90 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -57,7 +57,7 @@ def flashinfer_sample( draft_token_ids[i, :num_spec_tokens] = torch.tensor( spec_token_ids[i], device="cpu", dtype=torch.long) end_loc = start_loc + num_spec_tokens + 1 - # Assume greedy sampling here + # Assume greedy sampling. target_token_ids[i, :num_spec_tokens + 1] = torch.argmax( logits[start_loc:end_loc], dim=-1) start_loc = end_loc @@ -86,7 +86,7 @@ def greedy_sample_ref( logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] - # Add 1 to include the 'bonus' token + # Add 1 to include the 'bonus' token. sample_lens = [x + 1 for x in spec_lens] output_token_ids = logits.argmax(dim=-1).view(-1) @@ -110,9 +110,8 @@ def greedy_sample_ref( # mismatch (cumprod turns 0 after a mismatch). accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod( dim=1) - # Identify valid positions (non-padding) + # Identify valid positions (non-padding). valid_mask = output_token_ids != INVALID_TOKEN_ID - # Generate mask with bonus token. generate_mask = torch.cat([ accept_mask, @@ -141,7 +140,7 @@ def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int, dtype=torch.float, device=out_device) - # Ignore INVALID_TOKEN_ID + # Ignore INVALID_TOKEN_ID. 
valid_mask = (token_ids != INVALID_TOKEN_ID) valid_indices = token_ids.clone() valid_indices[~valid_mask] = 0 From af7322e89fb924ea1fc9daa3a9fc50999f8fd826 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 12 Feb 2025 06:14:06 +0000 Subject: [PATCH 63/75] fix input prepare bug --- vllm/v1/sample/sampler.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 39 ++++++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 2d6cbf5ad1e1..632660e9fb79 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -30,7 +30,7 @@ def forward( if sampling_metadata.max_num_logprobs: raise NotImplementedError( "Rejection sampling does not support logprobs.") - return self.rejection_sampler.sample( + return self.rejection_sampler( logits, sampling_metadata, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7f82ebac96f7..ec51229f00c6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -437,10 +437,8 @@ def _prepare_inputs( if use_spec_decode: # Currently, we assume all speculated tokens are verified. # We don't support partial verification. - # 1. Get spec decode logits indices. - spec_decode_logits_indices = self.arange_cpu[: - total_num_scheduled_tokens] - # 2. Write spec_token_ids to input batch. + + # 1. Write spec_token_ids to input batch. # Step 1. Get req indices that perform spec decode and repeat # the req indices by the number of spec tokens. Note # for requests that don't perform spec decode, the @@ -475,6 +473,37 @@ def _prepare_inputs( # Step 2. Write spec token ids to input_ids_cpu. self.input_batch.token_ids_cpu_tensor.flatten().scatter_( 0, cumsums_spec_offsets, all_spec_token_ids) + + + # 2. Get spec decode logits indices. + # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] + # cu_num_tokens: [4, 104, 107, 207, 209] + # num_spec_tokens_list: [3, 0, 2, 0, 1] + # num_sampled_tokens: [4, 1, 3, 1, 2] + # spec_decode_logits_indices: + # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32) + num_sampled_tokens = num_spec_tokens_np + 1 + # logits_start_loc: [0, 103, 104, 206, 207] + logits_start_loc = cu_num_tokens - num_sampled_tokens + # [0, 103, 104, 206, 207] -> [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) + # The following three lines: + # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) + # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_sampled_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, + num_sampled_tokens) + # Step 3. 
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + total_num_sampled_tokens = num_sampled_tokens.sum() + sampled_arange = self.arange_np[:total_num_sampled_tokens] - cumsums_sampled_offsets + + # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> + # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + spec_decode_logits_indices = logits_start_loc + sampled_arange + # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large @@ -570,7 +599,7 @@ def _prepare_inputs( ) if use_spec_decode: - logits_indices = spec_decode_logits_indices.to(self.device, + logits_indices = torch.from_numpy(spec_decode_logits_indices).to(self.device, non_blocking=True) else: # NOTE(woosuk): Due to chunked prefills, the batch may contain From 4dec71deb1a3f2d1e66ef42187d6e45fb21e4918 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 12 Feb 2025 06:31:59 +0000 Subject: [PATCH 64/75] fix --- vllm/v1/worker/gpu_model_runner.py | 32 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ec51229f00c6..36d9b091bb73 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -437,7 +437,7 @@ def _prepare_inputs( if use_spec_decode: # Currently, we assume all speculated tokens are verified. # We don't support partial verification. - + # 1. Write spec_token_ids to input batch. # Step 1. Get req indices that perform spec decode and repeat # the req indices by the number of spec tokens. Note @@ -473,37 +473,39 @@ def _prepare_inputs( # Step 2. Write spec token ids to input_ids_cpu. self.input_batch.token_ids_cpu_tensor.flatten().scatter_( 0, cumsums_spec_offsets, all_spec_token_ids) - - + # 2. Get spec decode logits indices. # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] # cu_num_tokens: [4, 104, 107, 207, 209] - # num_spec_tokens_list: [3, 0, 2, 0, 1] + # num_spec_tokens_list: [3, 0, 2, 0, 1] # num_sampled_tokens: [4, 1, 3, 1, 2] # spec_decode_logits_indices: - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32) num_sampled_tokens = num_spec_tokens_np + 1 # logits_start_loc: [0, 103, 104, 206, 207] logits_start_loc = cu_num_tokens - num_sampled_tokens - # [0, 103, 104, 206, 207] -> [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + # [0, 103, 104, 206, 207] -> + # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) # The following three lines: # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) - # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - cumsums_sampled_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, - num_sampled_tokens) - # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] + # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_sampled_offsets = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 3. 
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] total_num_sampled_tokens = num_sampled_tokens.sum() - sampled_arange = self.arange_np[:total_num_sampled_tokens] - cumsums_sampled_offsets - + sampled_arange = (self.arange_np[:total_num_sampled_tokens] - + cumsums_sampled_offsets) + # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] spec_decode_logits_indices = logits_start_loc + sampled_arange - # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large @@ -599,8 +601,8 @@ def _prepare_inputs( ) if use_spec_decode: - logits_indices = torch.from_numpy(spec_decode_logits_indices).to(self.device, - non_blocking=True) + logits_indices = torch.from_numpy(spec_decode_logits_indices).to( + self.device, non_blocking=True) else: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token From 8758b9675c63a1a47c1af0ccda14f337990e3d56 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 12 Feb 2025 19:49:36 +0000 Subject: [PATCH 65/75] fix test --- tests/v1/sample/test_rejection_sampler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index a7127404e72d..94c40027729f 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -53,7 +53,7 @@ def test_perfect_match(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) @@ -68,7 +68,7 @@ def test_early_mismatch(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], dtype=torch.int, device=logits.device) @@ -83,7 +83,7 @@ def test_multiple_sequences(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]], dtype=torch.int, device=logits.device) @@ -98,7 +98,7 @@ def test_single_token_sequence(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) assert torch.equal(output.sampled_token_ids, expected) @@ -111,7 +111,7 @@ def test_empty_sequence(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) assert torch.equal(output.sampled_token_ids, expected) @@ -124,7 +124,7 @@ def test_multiple_mismatches(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID], [4, 8, 
INVALID_TOKEN_ID, INVALID_TOKEN_ID]], dtype=torch.int, @@ -145,7 +145,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) @@ -161,7 +161,7 @@ def test_logits_shape_handling(sampler): metadata = create_sampling_metadata(spec_tokens) logits = create_logits_tensor(output_tokens, vocab_size) - output = sampler.sample(logits, metadata) + output = sampler(logits, metadata) expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device) assert torch.equal(output.sampled_token_ids, expected) assert logits.shape[-1] == vocab_size From 8929ad19755fcdbd98fee10f804792d3637edd0c Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 01:43:53 +0000 Subject: [PATCH 66/75] fix comments --- vllm/v1/core/kv_cache_manager.py | 13 ++++++--- vllm/v1/core/scheduler.py | 32 +++++++++----------- vllm/v1/request.py | 2 +- vllm/v1/sample/rejection_sampler.py | 22 +++++++++----- vllm/v1/worker/gpu_model_runner.py | 45 +++++++++++++++-------------- 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8a2940f30def..017e625dcdba 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -84,7 +84,9 @@ def __init__( # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. - self.cached_block_num: Dict[str, int] = defaultdict(int) + # This is only used to track the RUNNING requests, we do not track the + # data for reempted ones. + self.num_cached_block: Dict[str, int] = defaultdict(int) self.prefix_cache_stats = PrefixCacheStats() @property @@ -244,7 +246,10 @@ def allocate_slots( if not self.enable_caching: return new_blocks - num_cached_blocks = self.cached_block_num[request.request_id] + num_cached_blocks = self.num_cached_block[request.request_id] + # Speculated tokens might be rejected in the future, so we does + # not cache any speculated tokens. We only cache blocks with + # generated (accepted) tokens. num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( request.spec_token_ids)) // self.block_size new_full_blocks = req_blocks[ @@ -258,7 +263,7 @@ def allocate_slots( full_blocks=new_full_blocks, prev_block=(req_blocks[num_cached_blocks - 1] if num_cached_blocks > 0 else None)) - self.cached_block_num[ + self.num_cached_block[ request.request_id] = num_full_blocks_after_append return new_blocks @@ -283,7 +288,7 @@ def free(self, request: Request) -> None: if block.ref_cnt == 0: self.free_block_queue.append(block) - self.cached_block_num.pop(request.request_id, None) + self.num_cached_block.pop(request.request_id, None) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 88ad399f8628..b9e0a2cca744 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -116,9 +116,7 @@ def schedule(self) -> "SchedulerOutput": # Encoder-related. scheduled_encoder_inputs: Dict[str, List[int]] = {} encoder_budget = self.max_num_encoder_input_tokens - - # Spec Decode-related. spec_decode: if any request in - # the scheduled batch uses speculative decoding. + # Spec decode-related. 
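# Maps request_id -> the draft token ids scheduled for verification in
# this step.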
scheduled_spec_decode_tokens: Dict[str, List[int]] = {} scheduled_timestamp = time.monotonic() @@ -480,23 +478,17 @@ def update_from_output( request.num_computed_tokens += num_tokens_scheduled assert request.num_computed_tokens <= request.num_tokens else: - # num_computed_tokens_step is the number of tokens computed - # in the current step. - # num_computed_tokens_step = - # num_scheduled_tokens - num_tokens_rejected, - # where num_tokens_rejected = + # num_computed_tokens_step represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. + # It is calculated as: + # num_computed_tokens_step = num_scheduled_tokens - + # num_tokens_rejected, + # where num_tokens_rejected is given by: # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). - # We use this way of calculating num_computed_tokens_step - # because of chunked prefill. In chunked prefill, number of - # computed tokens is not equal to the number of - # generated/sampled tokens. Here, len(scheduled_spec_token_ids) - # + 1 is the maximum number of tokens generated in the current - # step, len(scheduled_spec_token_ids) + 1 - - # len(generated_token_ids) is the number of tokens rejected - # in the current step. scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get( - req_id, [])) + scheduler_output.scheduled_spec_decode_tokens[req_id]) + num_computed_tokens_step = num_scheduled_tokens[req_id] - ( len(scheduled_spec_token_ids) + 1 - len(generated_token_ids)) @@ -516,7 +508,8 @@ def update_from_output( request, input_id) if request.num_computed_tokens >= request.num_tokens: - # We assume all spec tokens are verified + # Clear the spec tokens as the request has generated + # a new token. Here, We assume all spec tokens are verified # if we perform speculative decoding for this request. # Therefore, we can clear all spec tokens after # the generation step. @@ -534,6 +527,7 @@ def update_from_output( request.append_output_token_ids(output_token_id) new_token_ids.append(output_token_id) + # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. stopped = self._check_stop(request) if stopped: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index afc5fb9aef16..a1bcc2d0393c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -114,7 +114,7 @@ def append_spec_token_ids( self.spec_token_ids.extend(token_ids) def clear_spec_tokens(self) -> None: - self.spec_token_ids = [] + self.spec_token_ids.clear() @property def num_tokens(self) -> int: diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 8aa9cb291a90..6a0bbe7b216f 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -26,21 +26,26 @@ def forward(self, logits: torch.Tensor, "Only greedy sampling is supported by rejection sampler.") if is_flashinfer_available: + logger.info("User FlashInfer for rejection sampling.") return RejectionSampler.flashinfer_sample(logits, sampling_metadata) else: - return RejectionSampler.greedy_sample_ref(logits, - sampling_metadata) + logger.warning( + "FlashInfer is not available. 
Falling back to the PyTorch-" + "native implementation of rejection sampling.") + return RejectionSampler.greedy_sample_native( + logits, sampling_metadata) @staticmethod def flashinfer_sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: + # NOTE: The following input preparationg can be moved + # to the model runner with a persistent manner for better + # performance. spec_token_ids = sampling_metadata.spec_token_ids - spec_lengths = torch.tensor([len(s) for s in spec_token_ids], - device="cpu") - max_spec_len = torch.max(spec_lengths).item() - batch_size = len(spec_lengths) + max_spec_len = max(len(s) for s in spec_token_ids) + batch_size = len(spec_token_ids) draft_token_ids = torch.full((batch_size, max_spec_len), INVALID_TOKEN_ID, device="cpu", @@ -51,6 +56,7 @@ def flashinfer_sample( device=logits.device, dtype=torch.long) + # TODO: Vectorize the following loop for better performance. start_loc = 0 for i in range(batch_size): num_spec_tokens = len(spec_token_ids[i]) @@ -63,6 +69,7 @@ def flashinfer_sample( start_loc = end_loc vocab_size = logits.size(-1) + # NOTE: CPU <-> GPU synchronization happens here. draft_token_ids = draft_token_ids.to(logits.device) draft_probs = RejectionSampler._create_greedy_token_probs( draft_token_ids, vocab_size, logits.device) @@ -81,8 +88,9 @@ def flashinfer_sample( return SamplerOutput(sampled_token_ids=sampled_token_ids, logprobs_tensors=None) + # TODO: The following method can be optimized for better performance. @staticmethod - def greedy_sample_ref( + def greedy_sample_native( logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> SamplerOutput: spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 36d9b091bb73..552d4e51d2e8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -435,8 +435,6 @@ def _prepare_inputs( use_spec_decode = len(all_spec_token_ids) > 0 if use_spec_decode: - # Currently, we assume all speculated tokens are verified. - # We don't support partial verification. # 1. Write spec_token_ids to input batch. # Step 1. Get req indices that perform spec decode and repeat @@ -467,7 +465,7 @@ def _prepare_inputs( cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to( torch.int64) all_spec_token_ids = torch.tensor(all_spec_token_ids, - device=self.input_ids_cpu.device, + device="cpu", dtype=self.input_ids_cpu.dtype) # Step 2. Write spec token ids to input_ids_cpu. @@ -759,7 +757,7 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _prepare_sampling( self, batch_changed: bool, - req_id_spec_token_ids: Dict[str, List[int]], + req_to_spec_token_ids: Dict[str, List[int]], ) -> SamplingMetadata: # Create the sampling metadata. req_id_output_token_ids: Dict[str, List[int]] = \ @@ -767,7 +765,7 @@ def _prepare_sampling( for req_id, req in self.requests.items()} sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, req_id_spec_token_ids, not batch_changed) + req_id_output_token_ids, req_to_spec_token_ids, not batch_changed) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -937,7 +935,6 @@ def execute_model( sampling_metadata=sampling_metadata, ) - sampled_token_ids = sampler_output.sampled_token_ids # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. 
num_reqs = self.input_batch.num_reqs @@ -948,8 +945,6 @@ def execute_model( seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) if seq_len >= req_state.num_tokens: - # We don't rewind the generator state for requests now - # because spec decode only supports greedy decoding for now. request_seq_lens.append((i, req_state, seq_len)) else: # Ignore the sampled token from the partial request. @@ -965,9 +960,10 @@ def execute_model( self.input_batch.req_ids[:num_reqs]), "req_ids contains None" req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - logprobs_tensors = sampler_output.logprobs_tensors # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. + sampled_token_ids = sampler_output.sampled_token_ids + logprobs_tensors = sampler_output.logprobs_tensors logprobs_lists = logprobs_tensors.tolists() \ if logprobs_tensors is not None else None @@ -978,18 +974,25 @@ def execute_model( ) # Update batch with the valid generated tokens. - valid_mask = sampled_token_ids != INVALID_TOKEN_ID - gen_lens = valid_mask.sum(dim=1).tolist() - valid_sampled_token_ids = [ - seq.tolist() - for seq in sampled_token_ids[valid_mask].split(gen_lens) - ] - self.input_batch.num_tokens[:num_reqs] += gen_lens - for i, req_state, seq_len in request_seq_lens: - target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[ - i, target_slice] = valid_sampled_token_ids[i] - req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + for i, req_state, seq_len in request_seq_lens: + token_id = sampled_token_ids[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + else: + valid_mask = sampled_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + valid_sampled_token_ids = [ + seq.tolist() + for seq in sampled_token_ids[valid_mask].split(gen_lens) + ] + self.input_batch.num_tokens[:num_reqs] += gen_lens + for i, req_state, seq_len in request_seq_lens: + target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) + self.input_batch.token_ids_cpu[ + i, target_slice] = valid_sampled_token_ids[i] + req_state.output_token_ids.extend(valid_sampled_token_ids[i]) model_runner_output = ModelRunnerOutput( req_ids=req_ids, From 65bb67fcacce2bee8a0886da0dd9a213022b96a6 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 02:12:08 +0000 Subject: [PATCH 67/75] minor fix --- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 552d4e51d2e8..0034581a8dc5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -962,7 +962,6 @@ def execute_model( # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. - sampled_token_ids = sampler_output.sampled_token_ids logprobs_tensors = sampler_output.logprobs_tensors logprobs_lists = logprobs_tensors.tolists() \ if logprobs_tensors is not None else None @@ -974,10 +973,12 @@ def execute_model( ) # Update batch with the valid generated tokens. 
+ sampled_token_ids = sampler_output.sampled_token_ids max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: + valid_sampled_token_ids = sampled_token_ids.tolist() for i, req_state, seq_len in request_seq_lens: - token_id = sampled_token_ids[i] + token_id = sampled_token_ids[i][0] self.input_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids.append(token_id) else: From 992aab82f6bc0d0d48b7d94d55f5a86734431ec9 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 02:13:23 +0000 Subject: [PATCH 68/75] make test more deterministic --- tests/v1/e2e/test_ngram_spec_decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_ngram_spec_decode.py index a45ed173752f..150caa150a59 100644 --- a/tests/v1/e2e/test_ngram_spec_decode.py +++ b/tests/v1/e2e/test_ngram_spec_decode.py @@ -8,7 +8,7 @@ def test_prompts(): return [ "Can you repeat the sentence ten times, this is a sentence.", - "Can you repeat the sentence ten times, the future of AI is right.", + "Can you repeat the sentence ten times, this is a test.", ] From e298bb34ab0baaafd397d4dc1830dabbbf6af4c5 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 02:16:29 +0000 Subject: [PATCH 69/75] merge conflict --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b64581bf5f42..8635ffce7027 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -695,7 +695,7 @@ def execute_model( model_runner_output = ModelRunnerOutput( req_ids=all_req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=sampled_token_ids, + sampled_token_ids=[[token_id] for token_id in sampled_token_ids], logprobs=None, prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] ) From 4e015ae7ffc7c3e82398470ee173a4d2e4ff6784 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 14 Feb 2025 21:48:34 -0800 Subject: [PATCH 70/75] fix --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d2ab5c1fbfaf..13f086e03b43 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -987,7 +987,7 @@ def execute_model( if max_gen_len == 1: valid_sampled_token_ids = sampled_token_ids.tolist() for i, req_state, seq_len in request_seq_lens: - token_id = sampled_token_ids[i][0] + token_id = valid_sampled_token_ids[i][0] self.input_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids.append(token_id) else: From 5fc52640c56e4a5b1b18e99d71ff021f82aeea7f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 13:13:12 -0800 Subject: [PATCH 71/75] fix rejection sampler tests --- tests/v1/sample/test_rejection_sampler.py | 44 +++++++++++++---------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 94c40027729f..8bc33e84194c 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -24,25 +24,31 @@ def create_logits_tensor(token_ids: List[int], def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: - return SamplingMetadata(temperature=0.0, - all_greedy=True, - all_random=False, - rejection_sampling=True, - 
spec_token_ids=spec_tokens, - top_p=None, - top_k=None, - no_top_p=False, - no_top_k=False, - generators={}, - max_num_logprobs=0, - no_penalties=False, - prompt_token_ids=None, - frequency_penalties=torch.tensor([]), - presence_penalties=torch.tensor([]), - repetition_penalties=torch.tensor([]), - output_token_ids=[], - min_tokens=[], - stop_token_ids=[]) + batch_size = len(spec_tokens) + return SamplingMetadata( + temperature=0.0, + all_greedy=True, + all_random=False, + rejection_sampling=True, + spec_token_ids=spec_tokens, + top_p=None, + top_k=None, + no_top_p=False, + no_top_k=False, + min_p=torch.empty(batch_size, ), + no_min_p=True, + generators={}, + max_num_logprobs=0, + no_penalties=False, + prompt_token_ids=None, + frequency_penalties=torch.tensor([]), + presence_penalties=torch.tensor([]), + repetition_penalties=torch.tensor([]), + output_token_ids=[], + min_tokens=[], + stop_token_ids=[], + logit_bias=[None] * batch_size, + ) def test_perfect_match(sampler): From b56a8e4c21a1cd69a6812495220fd0a61134aa38 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 13:42:33 -0800 Subject: [PATCH 72/75] fix num_token --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 13f086e03b43..104c93434b20 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -990,6 +990,7 @@ def execute_model( token_id = valid_sampled_token_ids[i][0] self.input_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids.append(token_id) + self.input_batch.num_tokens[i] += 1 else: valid_mask = sampled_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() From 2cbf57e975b31f103a5a039f843a8e418ff22971 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 13:55:32 -0800 Subject: [PATCH 73/75] fix scheduler test --- tests/v1/core/test_scheduler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 1266ad1824c5..fdf3d662c665 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -232,6 +232,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -280,6 +281,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -326,6 +328,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -372,6 +375,7 @@ def test_stop_via_update_from_output(): requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) + scheduler.scheduled_req_ids.add(requests[0].request_id) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], From 29d305475cec5ccde0b55dc80e1e328019d1cd46 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 
2025 13:56:51 -0800 Subject: [PATCH 74/75] fix scheduler test, minor --- tests/v1/core/test_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index fdf3d662c665..e39a7f9f40bd 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -433,7 +433,7 @@ def test_schedule_concurrent_batches(): model_runner_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, - sampled_token_ids=[0], + sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, ) @@ -449,7 +449,7 @@ def test_schedule_concurrent_batches(): model_runner_output = ModelRunnerOutput( req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, - sampled_token_ids=[0], + sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, ) From 2dc79093c55b859b13f69451787df5a1bae9d597 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 15 Feb 2025 14:14:13 -0800 Subject: [PATCH 75/75] fix gpu model runner --- tests/v1/worker/test_gpu_model_runner.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index f5219b676a83..576d906fa749 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -66,6 +66,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_cached_reqs=[], num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), @@ -109,6 +110,7 @@ def test_update_states_request_finished(model_runner): scheduled_cached_reqs=[], num_scheduled_tokens={}, total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, @@ -137,6 +139,7 @@ def test_update_states_request_resumed(model_runner): scheduled_cached_reqs=[], num_scheduled_tokens={}, total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={}, @@ -160,6 +163,7 @@ def test_update_states_request_resumed(model_runner): scheduled_cached_reqs=[cached_req_data], num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, + scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), @@ -188,6 +192,7 @@ def test_update_states_no_changes(model_runner): scheduled_cached_reqs=[], num_scheduled_tokens={req_id: 1}, total_num_scheduled_tokens=1, + scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), @@ -220,6 +225,7 @@ def test_update_states_request_unscheduled(model_runner): scheduled_cached_reqs=[], num_scheduled_tokens={req_ids[0]: 1}, total_num_scheduled_tokens=1, + scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(),
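# Taken together, the series implements greedy draft-token verification.
# A sketch of the acceptance rule only (illustrative names, not the vLLM
# implementation itself): drafts are accepted left to right until the
# target model's argmax first disagrees, after which a single token from
# the target is emitted; if every draft matches, the extra "bonus" token
# is emitted as well. In the batched implementations above, shorter rows
# are padded with INVALID_TOKEN_ID (-1).
from typing import List

def greedy_verify(draft: List[int], target_argmax: List[int]) -> List[int]:
    # target_argmax has len(draft) + 1 entries: one per draft position
    # plus the bonus position after the last draft token.
    accepted: List[int] = []
    for i, d in enumerate(draft):
        if d != target_argmax[i]:
            # First mismatch: emit the target's token here and stop.
            accepted.append(target_argmax[i])
            return accepted
        accepted.append(d)
    # All drafts accepted: also emit the bonus token.
    accepted.append(target_argmax[-1])
    return accepted

# E.g. drafts [1, 2, 3] with target argmax [1, 2, 4, 5] yield [1, 2, 4],
# and drafts [1, 2, 3] with target argmax [1, 2, 3, 4] yield [1, 2, 3, 4].
assert greedy_verify([1, 2, 3], [1, 2, 4, 5]) == [1, 2, 4]
assert greedy_verify([1, 2, 3], [1, 2, 3, 4]) == [1, 2, 3, 4]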