Merged
Commits (30)
95b1b7d  first prototype, working for BS=1  (benchislett, Aug 24, 2025)
5bd9851  wip for batched  (benchislett, Aug 25, 2025)
9a59696  fix bs > 1  (benchislett, Aug 25, 2025)
a24d715  add back removed code  (benchislett, Aug 25, 2025)
51b6169  minor perf optimization  (benchislett, Aug 25, 2025)
2832e37  improvements  (benchislett, Aug 26, 2025)
bd331b4  remove old prints  (benchislett, Aug 26, 2025)
dfa5ca9  Merge branch 'main' into overlap-model-execution  (benchislett, Aug 26, 2025)
c118525  fix precommit  (benchislett, Aug 26, 2025)
43b4f17  Merge branch 'main' into overlap-model-execution  (benchislett, Aug 27, 2025)
752ccf9  misc cleanup  (benchislett, Aug 27, 2025)
9f28326  refactor prepare_input_ids  (benchislett, Aug 27, 2025)
15d7b31  tiny refactor to reorder some ops  (benchislett, Aug 27, 2025)
b351a56  Merge branch 'main' into overlap-model-execution  (benchislett, Sep 2, 2025)
5df3ae8  refactor async model runner output  (benchislett, Sep 2, 2025)
efcc3ee  tiny cleanup  (benchislett, Sep 2, 2025)
b4611f4  Merge branch 'main' into overlap-model-execution  (benchislett, Sep 2, 2025)
6c025bb  remove torch from multiproc_executor  (benchislett, Sep 2, 2025)
bc99a79  refactor async output in multiproc executor  (benchislett, Sep 3, 2025)
2ffa123  cleanup  (benchislett, Sep 3, 2025)
7ae3166  improve async gpu model runner output structure  (benchislett, Sep 3, 2025)
75c109d  use cuda event to sync copy stream  (benchislett, Sep 3, 2025)
3f9d46b  Merge branch 'main' into overlap-model-execution  (benchislett, Sep 3, 2025)
ff5bc7a  minor refactor for readability  (benchislett, Sep 4, 2025)
6a44032  more minor refactor  (benchislett, Sep 4, 2025)
b411981  Merge branch 'main' into overlap-model-execution  (benchislett, Sep 4, 2025)
0d23f0e  refactor prepare_input_ids for fewer cpu ops  (benchislett, Sep 4, 2025)
54feea9  restructure multiproc output handling to isolate effects on non-async…  (benchislett, Sep 5, 2025)
4bddae2  Merge branch 'main' into overlap-model-execution  (benchislett, Sep 5, 2025)
70f4921  Merge branch 'main' into overlap-model-execution  (benchislett, Sep 5, 2025)
vllm/v1/executor/multiproc_executor.py (43 additions, 3 deletions)
@@ -3,6 +3,7 @@
import multiprocessing
import os
import pickle
import queue
import signal
import threading
import time
@@ -18,6 +19,7 @@
from typing import Any, Callable, Optional, Union, cast

import cloudpickle
import torch

import vllm.envs as envs
from vllm.config import VllmConfig
@@ -33,7 +35,8 @@
get_loopback_ip, get_mp_context, get_open_port,
set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
ModelRunnerOutput)
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -586,6 +589,44 @@ class ResponseStatus(Enum):

def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""

def process_output(output: Any, worker_response_mq: MessageQueue,
copy_stream: torch.cuda.Stream) -> None:
if isinstance(output, AsyncModelRunnerOutput):
# Serialize the sampled token ids before sending the output
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(copy_stream):
copy_stream.wait_stream(default_stream)
sampled_token_ids_list = output.sampled_token_ids_tensor.to(
'cpu', non_blocking=True)
copy_stream.synchronize()
sampled_token_ids_list = sampled_token_ids_list.tolist()
for i in output.invalid_req_indices:
sampled_token_ids_list[i].clear()
output = output.model_runner_output
output.sampled_token_ids = sampled_token_ids_list

worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
return

def _output_processor_loop(input_queue: queue.Queue,
worker_response_mq: MessageQueue,
copy_stream: torch.cuda.Stream):
while True:
output = input_queue.get()
process_output(output, worker_response_mq, copy_stream)

output_queue: queue.Queue = queue.Queue()
copy_stream_ = torch.cuda.Stream()
output_processor_thread = Thread(target=_output_processor_loop,
args=(output_queue,
self.worker_response_mq,
copy_stream_),
daemon=True,
name="WorkerOutputProcessor")
output_processor_thread.start()

while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

@@ -608,5 +649,4 @@ def worker_busy_loop(self):
continue

if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
output_queue.put(output)
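
The worker-side change above moves device-to-host serialization off the busy loop: outputs are pushed onto a local queue, and a daemon thread copies the sampled-token tensor to the CPU on a dedicated CUDA stream before enqueueing the response. Below is a minimal standalone sketch of that pattern, not vLLM code; the names drain_outputs, in_q, and out_q are illustrative.

import queue
import threading

import torch


def drain_outputs(in_q: queue.Queue, out_q: queue.Queue,
                  copy_stream: torch.cuda.Stream) -> None:
    while True:
        gpu_tensor = in_q.get()
        if gpu_tensor is None:  # sentinel to stop the worker
            break
        # Grab the default stream first, then make the copy stream wait on it
        # so the copy only starts after the producing kernels have finished.
        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(copy_stream):
            copy_stream.wait_stream(default_stream)
            host_tensor = gpu_tensor.to("cpu", non_blocking=True)
        copy_stream.synchronize()  # the copy is complete after this point
        out_q.put(host_tensor.tolist())


if torch.cuda.is_available():
    in_q: queue.Queue = queue.Queue()
    out_q: queue.Queue = queue.Queue()
    worker = threading.Thread(target=drain_outputs,
                              args=(in_q, out_q, torch.cuda.Stream()),
                              daemon=True)
    worker.start()
    in_q.put(torch.arange(8, device="cuda"))
    print(out_q.get())  # [0, 1, 2, 3, 4, 5, 6, 7]
    in_q.put(None)
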
vllm/v1/outputs.py (13 additions, 0 deletions)
@@ -114,6 +114,19 @@ class ModelRunnerOutput:
num_nans_in_logits: Optional[dict[str, int]] = None


# Subclass of ModelRunnerOutput for async scheduling.
# Contains GPU tensors which must be serialized before sending
# to the scheduler process.
@dataclass
class AsyncModelRunnerOutput:
model_runner_output: ModelRunnerOutput

# [num_reqs, max_num_generated_tokens]
sampled_token_ids_tensor: torch.Tensor

invalid_req_indices: list[int]


@dataclass
class DraftTokenIds:

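
AsyncModelRunnerOutput carries the raw GPU tensor of sampled token ids plus the indices of requests whose samples must be discarded; a consumer turns it into per-request token lists only once the data reaches the CPU. A hedged sketch of that resolution step follows; this is not the vLLM API, and AsyncOutputSketch and resolve are illustrative names.

from dataclasses import dataclass

import torch


@dataclass
class AsyncOutputSketch:
    # [num_reqs, 1] sampled token ids; in vLLM this tensor lives on the GPU.
    sampled_token_ids_tensor: torch.Tensor
    invalid_req_indices: list[int]

    def resolve(self) -> list[list[int]]:
        # Materialize on the CPU, then clear rows that must not be sampled.
        token_lists = self.sampled_token_ids_tensor.cpu().tolist()
        for i in self.invalid_req_indices:
            token_lists[i].clear()
        return token_lists


out = AsyncOutputSketch(torch.tensor([[7], [8], [9]]), invalid_req_indices=[1])
print(out.resolve())  # [[7], [], [9]]
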
vllm/v1/worker/gpu_input_batch.py (5 additions, 0 deletions)
@@ -250,6 +250,11 @@ def __init__(

self.pooling_params: dict[str, PoolingParams] = {}

# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None

@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
vllm/v1/worker/gpu_model_runner.py (120 additions, 25 deletions)
@@ -67,8 +67,8 @@
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
@@ -232,6 +232,8 @@ def __init__(
is_pooling_model=self.is_pooling_model,
)

self.use_async_scheduling = self.scheduler_config.async_scheduling

# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
@@ -745,7 +747,64 @@ def _prepare_inputs(
max_seq_len = self.seq_lens.np[:num_reqs].max().item()

# Copy the tensors to the GPU.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.input_batch.prev_sampled_token_ids is not None:
# Async scheduling case, we need to copy the sampled token ids
# from the previous iteration.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
current_req_id_to_index = self.input_batch.req_id_to_index
assert prev_req_id_to_index is not None
common_req_ids = set(prev_req_id_to_index.keys()).intersection(
set(current_req_id_to_index.keys()))
if common_req_ids:
current_common_req_indices = [
current_req_id_to_index[req_id]
for req_id in common_req_ids
]
prev_common_req_indices = [
prev_req_id_to_index[req_id] for req_id in common_req_ids
]
# We need to compute the flattened input_ids index of the
# last token in each common request.
flattened_indices = [
int(cu_num_tokens[idx]) - 1
for idx in current_common_req_indices
]
if len(flattened_indices) < total_num_scheduled_tokens:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if flattened_indices == prev_common_req_indices and \
set(flattened_indices) == \
set(range(len(flattened_indices))):
# Common-case optimization: the batch is unchanged
# and no reordering happened.
# The indices are both the same permutation of 0..N-1
self.input_ids.gpu[:len(flattened_indices)].copy_(
self.input_batch.prev_sampled_token_ids[:len(
flattened_indices)].squeeze(1),
non_blocking=True)
else:
# Upload the index tensors asynchronously
# so the scatter can be non-blocking
input_ids_index_tensor = torch.tensor(
flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device,
non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device,
non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
index=input_ids_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor].squeeze(1))
else:
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
else:
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
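
The branch above either takes a fast path (the batch is unchanged, so the previous step's tokens can be copied with a plain slice) or scatters each cached GPU token into the flattened input_ids buffer at that request's last-token position, avoiding any CPU synchronization. A small sketch of the scatter step under made-up indices follows; it is not vLLM code.

import torch

if torch.cuda.is_available():
    device = "cuda"
    # Flattened token buffer for the whole batch (10 scheduled tokens here).
    input_ids_gpu = torch.zeros(10, dtype=torch.int64, device=device)
    # [num_reqs, 1] tokens sampled on the previous step, still on the GPU.
    prev_sampled = torch.tensor([[11], [22], [33]], device=device)

    # Position of each request's last scheduled token in the flat buffer,
    # and the matching row of prev_sampled for that request.
    dest_positions = torch.tensor([2, 5, 9], device=device)
    src_rows = torch.tensor([0, 1, 2], device=device)

    input_ids_gpu.scatter_(dim=0,
                           index=dest_positions,
                           src=prev_sampled[src_rows].squeeze(1))
    print(input_ids_gpu)  # positions 2, 5, 9 now hold 11, 22, 33
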
@@ -1472,7 +1531,7 @@ def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
@@ -1692,29 +1751,55 @@ def execute_model(
scheduler_output.num_scheduled_tokens,
)

# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
if not self.use_async_scheduling:
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
valid_sampled_token_ids = []
sampled_token_ids_tensor = sampler_output.sampled_token_ids
invalid_req_indices = list(discard_sampled_tokens_req_indices)
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids_tensor.shape[-1] == 1

# Cache the sampled tokens on the GPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = \
sampled_token_ids_tensor
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
if i not in invalid_req_indices_set
}

# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [-1] * 1 if \
req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
if not sampled_ids:
continue

@@ -1725,12 +1810,13 @@ def execute_model(
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")

self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx

req_id = req_ids[req_idx]
req_state = self.requests[req_id]
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
req_state.output_token_ids.extend(sampled_ids)

if self.speculative_config:
@@ -1748,17 +1834,26 @@ def execute_model(

self.eplb_step()

return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids.copy(),
req_id_to_index=self.input_batch.req_id_to_index.copy(),
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
num_nans_in_logits=num_nans_in_logits.copy(),
)

if self.use_async_scheduling:
return AsyncModelRunnerOutput(
model_runner_output=output,
sampled_token_ids_tensor=sampled_token_ids_tensor,
invalid_req_indices=invalid_req_indices,
)

return output

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
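
With async scheduling, execute_model no longer waits for the sampled tokens to reach the CPU: per-request length bookkeeping advances with a -1 placeholder for each live request, while the real value stays on the GPU (in prev_sampled_token_ids) until the next step consumes it. The toy illustration below, with made-up shapes and not vLLM code, shows only that bookkeeping idea.

import torch

PLACEHOLDER = -1

token_ids_cpu = torch.zeros((2, 8), dtype=torch.int64)  # [max_reqs, max_len]
num_tokens = [3, 5]            # tokens already materialized per request
invalid_req_indices = {1}      # e.g. a request whose sample is discarded

# The real sampled values; in vLLM this tensor stays on the GPU.
prev_sampled_token_ids = torch.tensor([[42], [99]])

for req_idx in range(2):
    if req_idx in invalid_req_indices:
        continue
    end = num_tokens[req_idx] + 1
    token_ids_cpu[req_idx, end - 1] = PLACEHOLDER  # real value arrives later
    num_tokens[req_idx] = end

print(token_ids_cpu[0, :4].tolist())  # [0, 0, 0, -1]; token 42 is still on the GPU
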
vllm/v1/worker/gpu_worker.py (5 additions, 5 deletions)
@@ -5,7 +5,7 @@
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
import torch.distributed
@@ -28,8 +28,8 @@
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
@@ -352,7 +352,7 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
@@ -383,7 +383,7 @@ def execute_model(
output.kv_connector_output = kv_connector_output
return output

assert isinstance(output, ModelRunnerOutput)
assert isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput))
return output

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
Expand Down