from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
-                             LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
+                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -156,6 +156,53 @@ def graph_capture(device: torch.device):
        yield graph_capture_context


+# Wrapper for ModelRunnerOutput to support overlapped execution.
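+# The NPU -> CPU copy of the sampled tokens is started in __init__ on a side
+# stream and only awaited in get_output(), so the next step can be scheduled
+# while the copy is still in flight.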
+class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampled_token_ids: torch.Tensor,
+        invalid_req_indices: list[int],
+        async_output_copy_stream: torch.npu.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._invalid_req_indices = invalid_req_indices
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self._async_copy_ready_event = torch.npu.Event()
+
+        # Keep a reference to the device tensor to avoid it being
+        # deallocated until we finish copying it to the host.
+        self._sampled_token_ids = sampled_token_ids
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.npu.current_stream()
+        with torch.npu.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
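+            # wait_stream makes the copy stream wait for the work already
+            # queued on the default stream, so the device-to-host copy below
+            # reads the final sampled values.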
+            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
+                'cpu', non_blocking=True)
+            self._async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+
+        This function blocks until the copy is finished.
+        """
+        self._async_copy_ready_event.synchronize()
+
+        # Release the device tensor once the copy has completed.
+        del self._sampled_token_ids
+
+        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
+        for i in self._invalid_req_indices:
+            valid_sampled_token_ids[i].clear()
+
+        output = self._model_runner_output
+        output.sampled_token_ids = valid_sampled_token_ids
+        return output
+
+
class NPUModelRunner(LoRAModelRunnerMixin):

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
@@ -358,6 +405,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
            device=self.device,
        )

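+        # Async scheduling lets the engine prepare step N+1 on the CPU while
+        # step N is still running on the NPU; a dedicated stream is used for
+        # the sampled-token copy so it does not block the default stream.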
+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        self.async_output_copy_stream = torch.npu.Stream() if \
+            self.use_async_scheduling else None
+
    def _use_aclgraph(self) -> bool:
        return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager

@@ -845,6 +896,76 @@ def _get_cumsum_and_arange(

        return cu_num_tokens, arange

+    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
+                           cu_num_tokens: np.ndarray) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles `prev_sampled_token_ids`, which can be cached from
+        the previous engine iteration, in which case those tokens on the NPU
+        need to be copied into the corresponding slots of input_ids."""
+
+        if self.input_batch.prev_sampled_token_ids is None:
+            # Normal scheduling case.
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the NPU from prev_sampled_token_ids.
+        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
+        assert prev_req_id_to_index is not None
+        flattened_indices = []
+        prev_common_req_indices = []
+        indices_match = True
+        max_flattened_index = -1
+        for req_id, cur_index in self.input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                flattened_indices.append(flattened_index)
+                indices_match &= (prev_index == flattened_index)
+                max_flattened_index = max(max_flattened_index, flattened_index)
+        num_common_tokens = len(flattened_indices)
+        if num_common_tokens < total_num_scheduled_tokens:
+            # If not all requests are decodes from the last iteration,
+            # we need to copy the input_ids_cpu to the NPU first.
+            self.input_ids[:total_num_scheduled_tokens].copy_(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                non_blocking=True)
+        if num_common_tokens == 0:
+            # No requests in common with the previous iteration,
+            # so input_ids_cpu will have all the input ids.
+            return
+        if indices_match and max_flattened_index == (num_common_tokens - 1):
+            # Common-case optimization: the batch is unchanged
+            # and no reordering happened.
+            # The indices are both the same permutation of 0..N-1, so
+            # we can copy directly using a single slice.
+            self.input_ids[:num_common_tokens].copy_(
+                self.input_batch.prev_sampled_token_ids[:num_common_tokens,
+                                                        0],
+                non_blocking=True)
+            return
+        # Upload the index tensors asynchronously
+        # so the scatter can be non-blocking.
+        input_ids_index_tensor = torch.tensor(flattened_indices,
+                                              dtype=torch.int64,
+                                              pin_memory=self.pin_memory).to(
+                                                  self.device,
+                                                  non_blocking=True)
+        prev_common_req_indices_tensor = torch.tensor(
+            prev_common_req_indices,
+            dtype=torch.int64,
+            pin_memory=self.pin_memory).to(self.device, non_blocking=True)
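+        # Scatter the previous iteration's sampled tokens (still resident on
+        # the NPU) into the flattened input_ids slot of each matching request.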
+        self.input_ids.scatter_(dim=0,
+                                index=input_ids_index_tensor,
+                                src=self.input_batch.prev_sampled_token_ids[
+                                    prev_common_req_indices_tensor, 0])
+
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
@@ -1033,6 +1154,16 @@ def _prepare_inputs(
        if self.vllm_config.model_config.use_mla:
            attn_metadata.num_input_tokens = num_input_tokens

+        # Prepare input_ids
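+        # token_ids_cpu is indexed as (request row, token position), so each
+        # scheduled token is gathered from its flattened index:
+        # row * row_stride + position.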
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        # Copy the tensors to the NPU.
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
+
        # _prepare_inputs may reorder the batch, so we must gather
        # multi-modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
@@ -1382,11 +1513,11 @@ def _select_moe_comm_method(self, num_tokens: int) -> str:
        2. If expert parallel is enabled, we need to consider the soc version and the
        number of tokens. This is based on the observation that all-gather is more
        efficient than all-to-all when running on A2.
-
+
        a. For A2, we choose from MC2 and all-gather.
-
+
        b. For A3, we choose from MC2 and all-to-all.
-
+
        In both cases, we use MC2 when the number of tokens is smaller than
        its capacity threshold.

@@ -1424,7 +1555,7 @@ def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
        with ProfileExecuteDuration().capture_async("prepare input"):
            self._update_states(scheduler_output)
            if not scheduler_output.total_num_scheduled_tokens:
@@ -1580,6 +1711,12 @@ def execute_model(
                        generator.set_offset(generator.get_offset() - 4)
                    discard_sampled_tokens_req_indices.append(i)

+            # Copy some objects so they don't get modified after returning.
+            # This is important when using async scheduling.
+            req_ids_output_copy = self.input_batch.req_ids.copy()
+            req_id_to_index_output_copy = \
+                self.input_batch.req_id_to_index.copy()
+
            # NOTE: NPU -> CPU Sync happens here.
            # Move as many CPU operations as possible before this sync point.
            logprobs_tensors = sampler_output.logprobs_tensors
@@ -1592,27 +1729,52 @@ def execute_model(
                scheduler_output,
            )

-            # Get the valid generated tokens.
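+            # One row of sampled_token_ids per request; used by the final loop
+            # below in both the sync and async paths.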
+            num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
            sampled_token_ids = sampler_output.sampled_token_ids
-            max_gen_len = sampled_token_ids.shape[-1]
-            if max_gen_len == 1:
-                # No spec decode tokens.
-                valid_sampled_token_ids = sampled_token_ids.tolist()
+            if not self.use_async_scheduling:
+                # Get the valid generated tokens.
+                max_gen_len = sampled_token_ids.shape[-1]
+                if max_gen_len == 1:
+                    # No spec decode tokens.
+                    valid_sampled_token_ids = sampled_token_ids.tolist()
+                else:
+                    # Includes spec decode tokens.
+                    valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                        sampled_token_ids,
+                        self.input_batch.vocab_size,
+                    )
+                # Mask out the sampled tokens that should not be sampled.
+                for i in discard_sampled_tokens_req_indices:
+                    valid_sampled_token_ids[i].clear()
            else:
-                # Includes spec decode tokens.
-                valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                    sampled_token_ids,
-                    self.input_batch.vocab_size,
-                )
-
-                for i in discard_sampled_tokens_req_indices:
-                    valid_sampled_token_ids[i].clear()
-            # Cache the sampled tokens in the model runner, so that the scheduler
+                valid_sampled_token_ids = []
+                invalid_req_indices = list(discard_sampled_tokens_req_indices)
+                invalid_req_indices_set = set(invalid_req_indices)
+                assert sampled_token_ids.shape[-1] == 1
+
+                # Cache the sampled tokens on the NPU and avoid CPU sync.
+                # These will be copied into input_ids in the next step
+                # when preparing inputs.
+                self.input_batch.prev_sampled_token_ids = \
+                    sampled_token_ids
+                self.input_batch.prev_sampled_token_ids_invalid_indices = \
+                    invalid_req_indices_set
+                self.input_batch.prev_req_id_to_index = {
+                    req_id: i
+                    for i, req_id in enumerate(self.input_batch.req_ids)
+                    if i not in invalid_req_indices_set
+                }
+            # Cache the sampled tokens in the model runner, so that the scheduler
            # doesn't need to send them back.
            # NOTE(woosuk): As an exception, when using PP, the scheduler sends
            # the sampled tokens back, because there's no direct communication
            # between the first-stage worker and the last-stage worker.
-            for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+            for req_idx in range(num_sampled_tokens):
+                if self.use_async_scheduling:
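+                    # Async case: the sampled token is still on the NPU, so a
+                    # placeholder id is recorded here; the real ids are
+                    # attached later in AsyncNPUModelRunnerOutput.get_output().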
+                    sampled_ids = [-1] * 1 if \
+                        req_idx not in invalid_req_indices_set else None
+                else:
+                    sampled_ids = valid_sampled_token_ids[req_idx]
                if not sampled_ids:
                    continue

@@ -1650,8 +1812,8 @@ def execute_model(
            extra_args = ({"kv_connector_output": kv_connector_output})

            model_runner_output = ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
+                req_ids=req_ids_output_copy,
+                req_id_to_index=req_id_to_index_output_copy,
                sampled_token_ids=valid_sampled_token_ids,
                logprobs=logprobs_lists,
                prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1669,7 +1831,15 @@ def execute_model(
                logger.info("Profile execute duration [%s]:%s", captured_name,
                            " ".join(dr_str))

-        return model_runner_output
+        if not self.use_async_scheduling:
+            return model_runner_output
+
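+        # Defer the NPU -> CPU sync: the caller receives a handle and resolves
+        # the final ModelRunnerOutput later via get_output().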
+        return AsyncNPUModelRunnerOutput(
+            model_runner_output=model_runner_output,
+            sampled_token_ids=sampled_token_ids,
+            invalid_req_indices=invalid_req_indices,
+            async_output_copy_stream=self.async_output_copy_stream,
+        )

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        if self._draft_token_ids is None: