Skip to content

Commit ac8fbb7

Browse files
committed
fix acc issue
1 parent 278d727 commit ac8fbb7

File tree

1 file changed: +30 −13 lines changed

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
 from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
     _resize_cache, count_expert_num_tokens)
 from vllm.utils import cdiv
-from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
+from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
+                                      dbo_maybe_run_recv_hook,
                                       dbo_register_recv_hook, dbo_yield)

1920
#
@@ -530,9 +531,6 @@ class FusedMoEModularKernel(torch.nn.Module):
     layer due to any layer specific state that may be used by the component
     objects.
     """
-    fused_out_buffer = SharedResizableBuffer()
-    workspace13_buffer = SharedResizableBuffer()
-    workspace2_buffer = SharedResizableBuffer()

     def __init__(
         self,
@@ -550,6 +548,19 @@ def __init__(
             f"{prepare_finalize.activation_format} == "
             f"{fused_experts.__class__.__name__}."
             f"{fused_experts.activation_formats[0]}")
+        # Initialize double buffers for ubatch 0 and ubatch 1
+        self._ubatch_buffers = [
+            {
+                "fused_out": SharedResizableBuffer(),
+                "workspace13": SharedResizableBuffer(),
+                "workspace2": SharedResizableBuffer(),
+            },
+            {
+                "fused_out": SharedResizableBuffer(),
+                "workspace13": SharedResizableBuffer(),
+                "workspace2": SharedResizableBuffer(),
+            },
+        ]

     def _do_fused_experts(
         self,
@@ -581,14 +592,18 @@ def _do_fused_experts(
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
             expert_tokens_meta)

+        # select per-ubatch buffers to avoid cross-ubatch reuse under DBO
+        ubatch_idx = dbo_current_ubatch_id()
+        buffers = self._ubatch_buffers[ubatch_idx]
+
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = self.workspace13_buffer.get(workspace13_shape,
-                                                  device=a1.device,
-                                                  dtype=workspace_dtype)
-        workspace2 = self.workspace2_buffer.get(workspace2_shape,
-                                                device=a1.device,
-                                                dtype=workspace_dtype)
+        workspace13 = buffers["workspace13"].get(workspace13_shape,
+                                                 device=a1.device,
+                                                 dtype=workspace_dtype)
+        workspace2 = buffers["workspace2"].get(workspace2_shape,
+                                               device=a1.device,
+                                               dtype=workspace_dtype)

         assert fused_out is None or fused_out.shape == fused_out_shape, (
             f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@@ -680,9 +695,11 @@ def _maybe_chunk_fused_experts(
         (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
             expert_tokens_meta)
-        fused_out = self.fused_out_buffer.get(fused_out_shape,
-                                              device=a1q.device,
-                                              dtype=a1.dtype)
+        ubatch_idx = dbo_current_ubatch_id()
+        buffers = self._ubatch_buffers[ubatch_idx]
+        fused_out = buffers["fused_out"].get(fused_out_shape,
+                                             device=a1q.device,
+                                             dtype=a1.dtype)

     def slice_input_tensors(
         chunk_idx: int

0 commit comments

Comments (0)