
Commit 68a6178

jiangpeng36 and Ronald1995 authored and committed
[Perf][V1] Fully overlap model execution (vllm-project#2783)
This PR is based on top of [#23569](vllm-project/vllm#23569) and [#24219](vllm-project/vllm#24219).

### What this PR does / why we need it?

This PR allows the model runner to run asynchronously when async scheduling is used, fully overlapping the CPU operations (including prepare_inputs) with the model forward pass. The change is functional but does not yet support speculative decoding, pipeline parallelism (PP), or guided decoding. The expected speedup is 5-10% over the current async scheduling.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Server:

```
python -m vllm.entrypoints.openai.api_server --model=Qwen3-32B \
  --trust-remote-code --enforce-eager \
  --distributed-executor-backend=mp \
  -tp=4 \
  --port 8006 \
  --max-model-len 32000 \
  --block-size 128 \
  --gpu-memory-utilization 0.99
```

Client:

```
python $TEST_PY --backend vllm --trust-remote-code --model Qwen3-32B \
  --dataset-name random --random-input-len 2048 --random-output-len 2048 \
  --ignore-eos \
  --num-prompts 48 --max-concurrency 48 --request-rate inf --temperature 0 \
  --metric-percentiles 90 --base-url http://localhost:8006 --save-result \
  --result-dir $PROFILER_DIR
```

TPOT benchmark results on Qwen3-32B:

| | forward async | scheduler async | sync |
|---|---|---|---|
| avg | 41.73 | 41.86 | 44.20 |
| improvement vs. scheduler async | 0.3% | 0 | 0 |
| improvement vs. sync | 5.58% | 0 | 0 |

TPOT benchmark results on Qwen2.5-VL-7B-Instruct:

| | forward async | sync |
|---|---|---|
| avg | 23.22 | 29.16 |
| improvement vs. sync | 20.3% | 0 |

- vLLM version: main
- vLLM main: vllm-project/vllm@e93f4cc

Signed-off-by: jiangpeng36 <[email protected]>
Signed-off-by: Ronald1995 <[email protected]>
Co-authored-by: jiangpeng36 <[email protected]>
Co-authored-by: Ronald1995 <[email protected]>
Signed-off-by: Yizhou Liu <[email protected]>
1 parent 01395c3 commit 68a6178
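For readers who want to try the overlapped path outside the serving benchmark above, a minimal offline sketch follows. It is illustrative rather than part of this commit, and assumes the `async_scheduling` engine argument (the same knob the new e2e test passes through `VllmRunner`) plus a small Qwen checkpoint.

```python
# Hypothetical usage sketch (not part of this commit): enable async scheduling
# so CPU-side input preparation overlaps with the device forward pass.
# Assumes `async_scheduling` is accepted as an engine argument, as the new
# test_async_scheduling e2e test does via VllmRunner.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", async_scheduling=True)
sampling_params = SamplingParams(temperature=0.2, max_tokens=10)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```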

File tree

4 files changed (+227, -29 lines)

tests/e2e/singlecard/test_ascend_scheduler.py

Lines changed: 23 additions & 0 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
+from vllm import SamplingParams
 
 from tests.e2e.conftest import VllmRunner
 from tests.e2e.model_utils import check_outputs_equal
@@ -86,3 +87,25 @@ def test_chunked_prefill_with_ascend_scheduler(
         name_0="vllm_output",
         name_1="chunked_prefill_output",
     )
+
+
+def test_async_scheduling() -> None:
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ] * 10
+    sampling_params = SamplingParams(temperature=0.2,
+                                     max_tokens=10,
+                                     stop_token_ids=None)
+
+    with VllmRunner(
+            "Qwen/Qwen2.5-0.5B-Instruct",
+            max_model_len=4096,
+            max_num_seqs=50,
+            dtype="bfloat16",
+            gpu_memory_utilization=0.9,
+            async_scheduling=True,
+    ) as vllm_model:
+        vllm_model.generate(prompts, sampling_params=sampling_params)
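To exercise only the new case locally, running `pytest tests/e2e/singlecard/test_ascend_scheduler.py -k test_async_scheduling` on an Ascend host should be enough; the small Qwen checkpoint it uses is fetched from the hub on first run.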

vllm_ascend/worker/model_runner_v1.py

Lines changed: 194 additions & 24 deletions

@@ -63,8 +63,8 @@
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -156,6 +156,53 @@ def graph_capture(device: torch.device):
     yield graph_capture_context
 
 
+# Wrapper for ModelRunnerOutput to support overlapped execution.
+class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampled_token_ids: torch.Tensor,
+        invalid_req_indices: list[int],
+        async_output_copy_stream: torch.npu.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._invalid_req_indices = invalid_req_indices
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self._async_copy_ready_event = torch.npu.Event()
+
+        # Keep a reference to the device tensor to avoid it being
+        # deallocated until we finish copying it to the host.
+        self._sampled_token_ids = sampled_token_ids
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.npu.current_stream()
+        with torch.npu.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
+            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+                'cpu', non_blocking=True)
+            self._async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+
+        This function blocks until the copy is finished.
+        """
+        self._async_copy_ready_event.synchronize()
+
+        # Release the device tensor once the copy has completed
+        del self._sampled_token_ids
+
+        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        for i in self._invalid_req_indices:
+            valid_sampled_token_ids[i].clear()
+
+        output = self._model_runner_output
+        output.sampled_token_ids = valid_sampled_token_ids
+        return output
+
+
 class NPUModelRunner(LoRAModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
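The `AsyncNPUModelRunnerOutput` wrapper added above carries the overlap: the device-to-host copy of the sampled tokens is launched on a side stream and only synchronized when `get_output()` is called. A minimal sketch of the same stream/event pattern is shown below, using `torch.cuda` as a stand-in since `torch.npu` only exists in Ascend builds; the class and variable names are illustrative.

```python
# Minimal sketch of a deferred device-to-host copy via a side stream + event.
# Illustrative only; torch.cuda stands in for torch.npu.
import torch


class DeferredHostCopy:

    def __init__(self, device_tensor: torch.Tensor,
                 copy_stream: torch.cuda.Stream):
        self._device_tensor = device_tensor  # keep the tensor alive until copied
        self._ready = torch.cuda.Event()
        # Make the copy stream wait for work already queued on the default stream.
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            # Non-blocking D2H copy; the default stream keeps running.
            self._host_tensor = device_tensor.to("cpu", non_blocking=True)
            self._ready.record()

    def get(self) -> torch.Tensor:
        # Block only at the point where the host copy is actually needed.
        self._ready.synchronize()
        return self._host_tensor


if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    sampled = torch.randint(0, 1000, (8, 1), device="cuda")
    pending = DeferredHostCopy(sampled, stream)
    # ... CPU-side work for the next step could proceed here ...
    print(pending.get().tolist())
```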
@@ -358,6 +405,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             device=self.device,
         )
 
+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        self.async_output_copy_stream = torch.npu.Stream() if \
+            self.use_async_scheduling else None
+
     def _use_aclgraph(self) -> bool:
         return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
 
@@ -845,6 +896,76 @@ def _get_cumsum_and_arange(
 
         return cu_num_tokens, arange
 
+    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
+                           cu_num_tokens: np.ndarray) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens on the
+        NPU need to be copied into the corresponding slots into input_ids."""
+
+        if self.input_batch.prev_sampled_token_ids is None:
+            # Normal scheduling case
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the NPU from prev_sampled_token_ids.
+        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
+        assert prev_req_id_to_index is not None
+        flattened_indices = []
+        prev_common_req_indices = []
+        indices_match = True
+        max_flattened_index = -1
+        for req_id, cur_index in self.input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                flattened_indices.append(flattened_index)
+                indices_match &= (prev_index == flattened_index)
+                max_flattened_index = max(max_flattened_index, flattened_index)
+        num_commmon_tokens = len(flattened_indices)
+        if num_commmon_tokens < total_num_scheduled_tokens:
+            # If not all requests are decodes from the last iteration,
+            # we need to copy the input_ids_cpu to the NPU first.
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+        if num_commmon_tokens == 0:
+            # No requests in common with the previous iteration,
+            # so input_ids_cpu will have all the input ids.
+            return
+        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
+            # Common-case optimization: the batch is unchanged
+            # and no reordering happened.
+            # The indices are both the same permutation of 0..N-1 so
+            # we can copy directly using a single slice.
+            self.input_ids[:num_commmon_tokens].copy_(
+                self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
+                                                        0],
+                non_blocking=True)
+            return
+        # Upload the index tensors asynchronously
+        # so the scatter can be non-blocking.
+        input_ids_index_tensor = torch.tensor(flattened_indices,
+                                              dtype=torch.int64,
+                                              pin_memory=self.pin_memory).to(
+                                                  self.device,
+                                                  non_blocking=True)
+        prev_common_req_indices_tensor = torch.tensor(
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
+        self.input_ids.scatter_(dim=0,
+                                index=input_ids_index_tensor,
+                                src=self.input_batch.prev_sampled_token_ids[
+                                    prev_common_req_indices_tensor, 0])
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
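The `_prepare_input_ids` helper added above avoids a host sync by scattering the previous step's on-device sampled tokens straight into the flattened `input_ids` buffer. A toy example of the underlying `scatter_` call, with made-up shapes and values:

```python
# Toy illustration of the scatter used in _prepare_input_ids: place each
# request's previously sampled token at the flattened index of its last
# scheduled position. All values here are made up for illustration.
import torch

input_ids = torch.zeros(6, dtype=torch.int64)              # flattened token buffer
prev_sampled_token_ids = torch.tensor([[11], [22], [33]])  # one token per request
flattened_indices = torch.tensor([1, 3, 5])                # last-token slot per request
prev_common_req_indices = torch.tensor([0, 1, 2])          # rows to read from

input_ids.scatter_(dim=0,
                   index=flattened_indices,
                   src=prev_sampled_token_ids[prev_common_req_indices, 0])
print(input_ids.tolist())  # [0, 11, 0, 22, 0, 33]
```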
@@ -1033,6 +1154,16 @@ def _prepare_inputs(
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens
 
+        # Prepare input_ids
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        # Copy the tensors to the NPU.
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
+
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
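The hunk above builds the host-side `input_ids_cpu` by turning (request row, position) pairs into flat indices over the per-request token table and gathering with `torch.index_select`. A small standalone illustration of that indexing pattern, with made-up values:

```python
# Toy illustration of the flattened gather used to build input_ids_cpu:
# token_indices = positions + req_indices * row_stride selects tokens from a
# (num_requests, max_len) table in scheduling order. Values are made up.
import numpy as np
import torch

token_ids_cpu_tensor = torch.tensor([[10, 11, 12, 13],
                                     [20, 21, 22, 23]])  # per-request token table
req_indices = np.array([0, 0, 1, 1])   # which request each scheduled slot serves
positions_np = np.array([2, 3, 0, 1])  # position within that request

row_stride = token_ids_cpu_tensor.shape[1]
token_indices = positions_np + req_indices * row_stride
out = torch.empty(4, dtype=token_ids_cpu_tensor.dtype)
torch.index_select(token_ids_cpu_tensor.flatten(), 0,
                   torch.from_numpy(token_indices), out=out)
print(out.tolist())  # [12, 13, 20, 21]
```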
@@ -1382,11 +1513,11 @@ def _select_moe_comm_method(self, num_tokens: int) -> str:
         2. If expert parallel is enabled, we need to consider the soc version and the
         number of tokens. This is based on the observation that all-gather is more
         efficient than all-to-all when running on A2.
-
+
             a. For A2, we choose from MC2 and all-gather.
-
+
             b. For A3, we choose from MC2 and all-to-all.
-
+
         In both cases, we use MC2 when the number of tokens is smaller than
         a its capacity threshold.
 
@@ -1424,7 +1555,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
@@ -1580,6 +1711,12 @@ def execute_model(
                     generator.set_offset(generator.get_offset() - 4)
                 discard_sampled_tokens_req_indices.append(i)
 
+        # Copy some objects so they don't get modified after returning.
+        # This is important when using async scheduling.
+        req_ids_output_copy = self.input_batch.req_ids.copy()
+        req_id_to_index_output_copy = \
+            self.input_batch.req_id_to_index.copy()
+
         # NOTE: NPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -1592,27 +1729,52 @@ def execute_model(
             scheduler_output,
         )
 
-        # Get the valid generated tokens.
+        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
+        if not self.use_async_scheduling:
+            # Get the valid generated tokens.
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
         else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
-            )
-
-        for i in discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
-        # Cache the sampled tokens in the model runner, so that the scheduler
+            valid_sampled_token_ids = []
+            invalid_req_indices = list(discard_sampled_tokens_req_indices)
+            invalid_req_indices_set = set(invalid_req_indices)
+            assert sampled_token_ids.shape[-1] == 1
+
+            # Cache the sampled tokens on the NPU and avoid CPU sync.
+            # These will be copied into input_ids in the next step
+            # when preparing inputs.
+            self.input_batch.prev_sampled_token_ids = \
+                sampled_token_ids
+            self.input_batch.prev_sampled_token_ids_invalid_indices = \
+                invalid_req_indices_set
+            self.input_batch.prev_req_id_to_index = {
+                req_id: i
+                for i, req_id in enumerate(self.input_batch.req_ids)
+                if i not in invalid_req_indices_set
+            }
+        # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
-        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+        for req_idx in range(num_sampled_tokens):
+            if self.use_async_scheduling:
+                sampled_ids = [-1] * 1 if \
+                    req_idx not in invalid_req_indices_set else None
+            else:
+                sampled_ids = valid_sampled_token_ids[req_idx]
             if not sampled_ids:
                 continue
 
@@ -1650,8 +1812,8 @@ def execute_model(
         extra_args = ({"kv_connector_output": kv_connector_output})
 
         model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
+            req_ids=req_ids_output_copy,
+            req_id_to_index=req_id_to_index_output_copy,
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1669,7 +1831,15 @@ def execute_model(
             logger.info("Profile execute duration [%s]:%s", captured_name,
                         " ".join(dr_str))
 
-        return model_runner_output
+        if not self.use_async_scheduling:
+            return model_runner_output
+
+        return AsyncNPUModelRunnerOutput(
+            model_runner_output=model_runner_output,
+            sampled_token_ids=sampled_token_ids,
+            invalid_req_indices=invalid_req_indices,
+            async_output_copy_stream=self.async_output_copy_stream,
+        )
 
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:

vllm_ascend/worker/npu_input_batch.py

Lines changed: 5 additions & 0 deletions

@@ -263,6 +263,11 @@ def __init__(
 
         self.pooling_params: dict[str, PoolingParams] = {}
 
+        # Cached reference to the GPU tensor of previously sampled tokens
+        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
+        self.prev_req_id_to_index: Optional[dict[str, int]] = None
+
     @property
     def req_ids(self) -> list[str]:
         # None elements should only be present transiently

vllm_ascend/worker/worker_v1.py

Lines changed: 5 additions & 5 deletions

@@ -18,7 +18,7 @@
 #
 
 import copy
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -38,8 +38,8 @@
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.worker.worker_base import WorkerBase
 
 from vllm_ascend.ascend_config import init_ascend_config
@@ -191,7 +191,7 @@ def determine_available_memory(self) -> int:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> Optional[ModelRunnerOutput]:
+    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
         intermediate_tensors = None
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
@@ -220,7 +220,7 @@ def execute_model(
             new_output.kv_connector_output = kv_connector_output
             return new_output
 
-        assert isinstance(output, ModelRunnerOutput)
+        assert isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput))
         return output
 
     def load_model(self) -> None:
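With the widened return type above, whichever component consumes the worker's output has to distinguish the async wrapper from a plain `ModelRunnerOutput`. The actual unwrapping lives in vLLM's engine core rather than in this patch; a hedged sketch of what that consumer-side check looks like, using only the `get_output()` accessor that `AsyncModelRunnerOutput` defines:

```python
# Illustrative sketch of how a caller might unwrap the async result.
# The real call site is in vLLM's engine core; this is not code from the patch.
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput


def unwrap(output):
    if isinstance(output, AsyncModelRunnerOutput):
        # Blocks until the NPU -> CPU copy of the sampled tokens has finished.
        return output.get_output()
    assert isinstance(output, ModelRunnerOutput)
    return output
```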
