13
13
from vllm .model_executor .layers .fused_moe .utils import ( # yapf: disable
14
14
_resize_cache , count_expert_num_tokens )
15
15
from vllm .utils import cdiv
16
- from vllm .v1 .worker .ubatching import (dbo_enabled , dbo_maybe_run_recv_hook ,
16
+ from vllm .v1 .worker .ubatching import (dbo_current_ubatch_id , dbo_enabled ,
17
+ dbo_maybe_run_recv_hook ,
17
18
dbo_register_recv_hook , dbo_yield )
18
19
19
20
#
@@ -530,9 +531,6 @@ class FusedMoEModularKernel(torch.nn.Module):
530
531
layer due to any layer specific state that may be used by the component
531
532
objects.
532
533
"""
533
- fused_out_buffer = SharedResizableBuffer ()
534
- workspace13_buffer = SharedResizableBuffer ()
535
- workspace2_buffer = SharedResizableBuffer ()
536
534
537
535
def __init__ (
538
536
self ,
@@ -550,6 +548,19 @@ def __init__(
550
548
f"{ prepare_finalize .activation_format } == "
551
549
f"{ fused_experts .__class__ .__name__ } ."
552
550
f"{ fused_experts .activation_formats [0 ]} " )
551
+ # Initialize double buffers for ubatch 0 and ubatch 1
552
+ self ._ubatch_buffers = [
553
+ {
554
+ "fused_out" : SharedResizableBuffer (),
555
+ "workspace13" : SharedResizableBuffer (),
556
+ "workspace2" : SharedResizableBuffer (),
557
+ },
558
+ {
559
+ "fused_out" : SharedResizableBuffer (),
560
+ "workspace13" : SharedResizableBuffer (),
561
+ "workspace2" : SharedResizableBuffer (),
562
+ },
563
+ ]
553
564
554
565
def _do_fused_experts (
555
566
self ,
@@ -581,14 +592,18 @@ def _do_fused_experts(
581
592
a1 , a1q , M , N , K , top_k , global_num_experts , local_num_experts ,
582
593
expert_tokens_meta )
583
594
595
+ # select per-ubatch buffers to avoid cross-ubatch reuse under DBO
596
+ ubatch_idx = dbo_current_ubatch_id ()
597
+ buffers = self ._ubatch_buffers [ubatch_idx ]
598
+
584
599
# We can reuse the memory between cache1 and cache3 because by the
585
600
# time we need cache3, we're done with cache1.
586
- workspace13 = self . workspace13_buffer .get (workspace13_shape ,
587
- device = a1 .device ,
588
- dtype = workspace_dtype )
589
- workspace2 = self . workspace2_buffer .get (workspace2_shape ,
590
- device = a1 .device ,
591
- dtype = workspace_dtype )
601
+ workspace13 = buffers [ "workspace13" ] .get (workspace13_shape ,
602
+ device = a1 .device ,
603
+ dtype = workspace_dtype )
604
+ workspace2 = buffers [ "workspace2" ] .get (workspace2_shape ,
605
+ device = a1 .device ,
606
+ dtype = workspace_dtype )
592
607
593
608
assert fused_out is None or fused_out .shape == fused_out_shape , (
594
609
f"fused_out { fused_out .shape } but expected { fused_out_shape } " )
@@ -680,9 +695,11 @@ def _maybe_chunk_fused_experts(
680
695
(_ , _ , fused_out_shape , _ ) = self .fused_experts .workspace_shapes (
681
696
a1 , a1q , M , N , K , top_k , global_num_experts , local_num_experts ,
682
697
expert_tokens_meta )
683
- fused_out = self .fused_out_buffer .get (fused_out_shape ,
684
- device = a1q .device ,
685
- dtype = a1 .dtype )
698
+ ubatch_idx = dbo_current_ubatch_id ()
699
+ buffers = self ._ubatch_buffers [ubatch_idx ]
700
+ fused_out = buffers ["fused_out" ].get (fused_out_shape ,
701
+ device = a1q .device ,
702
+ dtype = a1 .dtype )
686
703
687
704
def slice_input_tensors (
688
705
chunk_idx : int
0 commit comments