13 changes: 8 additions & 5 deletions vllm/config/__init__.py
@@ -3727,15 +3727,18 @@ def __post_init__(self):
"Compilation level should be CompilationLevel.PIECEWISE "\
"when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"

if self.parallel_config.enable_microbatching:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
assert a2a_backend == "deepep_low_latency", \
"Microbatching currently only supports the deepep_low_latency "\

assert a2a_backend in ["deepep_low_latency", \
"deepep_high_throughput"], "Microbatching currently only supports "
"the deepep_low_latency and deepep_high_throughput "\
f"all2all backend. {a2a_backend} is not supported. To fix set "\
"the VLLM_ALL2ALL_BACKEND environment variable to "\
"deepep_low_latency and install the DeepEP kerenls."

"deepep_low_latency or deepep_high_throughput "
"and install the DeepEP kerenls."

if not self.instance_id:
self.instance_id = random_uuid()[:5]

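For reference, the new check boils down to the following sketch (the environment-variable name and the supported backend values come from the diff above; the standalone helper is illustrative, not vLLM code):

    import os

    # Microbatching (DBO) is only wired up for the DeepEP all2all backends.
    SUPPORTED_A2A_BACKENDS = ("deepep_low_latency", "deepep_high_throughput")

    def check_microbatching_backend(enable_microbatching: bool) -> None:
        backend = os.environ.get("VLLM_ALL2ALL_BACKEND", "")
        if enable_microbatching and backend not in SUPPORTED_A2A_BACKENDS:
            raise ValueError(
                f"Microbatching requires VLLM_ALL2ALL_BACKEND to be one of "
                f"{SUPPORTED_A2A_BACKENDS}, got {backend!r}; also install "
                f"the DeepEP kernels.")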
14 changes: 13 additions & 1 deletion vllm/distributed/device_communicators/all2all.py
@@ -7,7 +7,8 @@

from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.v1.worker.ubatching import dbo_enabled

from .base_device_communicator import All2AllManagerBase, Cache

@@ -196,6 +197,17 @@ def get_handle(self, kwargs):
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
# Configure DeepGEMM to use the remaining SMs for compute.
# This avoids contention with communication.
if has_deep_gemm() and dbo_enabled():
import deep_gemm as dg
props = torch.cuda.get_device_properties(
torch.cuda.current_device())
total_sms = props.multi_processor_count
compute_sms = total_sms - self.num_sms
assert compute_sms > 0, "compute_sms must be greater than 0"
logger.info("Setting DeepGEMM num_sms to %d for dbo", compute_sms)
dg.set_num_sms(compute_sms)


Reviewer: Can we restrict this to just when the batch is actually running DBO, or will this do that already?

Author: I think we do that already; dbo_enabled() is one of the conditions.

return handle


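The SM split above can be illustrated in isolation (the helper below is a sketch, not vLLM code; it assumes a CUDA device is available and that DeepEP reserves num_sms SMs for communication):

    import torch

    def deepgemm_compute_sms(comm_sms: int) -> int:
        # DeepEP keeps `comm_sms` SMs for all2all communication; DeepGEMM is
        # then limited to the remaining SMs so compute and comms do not contend.
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        compute_sms = props.multi_processor_count - comm_sms
        assert compute_sms > 0, "comm SMs must leave at least one SM for compute"
        return compute_sms

    # For example, on a 132-SM H100 with 32 SMs reserved for DeepEP,
    # deep_gemm.set_num_sms() would receive 100.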
27 changes: 20 additions & 7 deletions vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -11,6 +11,9 @@
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id, dbo_yield_and_switch_from_comm_to_compute,
dbo_yield_and_switch_from_compute_to_comm)


class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -28,9 +31,9 @@ def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
self.async_prepare = True

# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None
# requires. Under DBO microbatching we must track one handle per
# micro-batch to avoid races between threads.
self.handles = [None, None]

# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
Expand Down Expand Up @@ -71,6 +74,8 @@ def _do_dispatch(

has_scales = token_scales is not None

dbo_yield_and_switch_from_compute_to_comm()

(num_tokens_per_rank, num_tokens_per_rdma_rank,
dispatch_expert_num_tokens, is_token_in_rank,
event) = self.buffer.get_dispatch_layout(
@@ -86,7 +91,7 @@

(
token_data, expert_topk_ids, expert_topk_weights,
expert_num_tokens_per_expert_list, self.handle, event
expert_num_tokens_per_expert_list, handle, event
) = self.buffer.dispatch(
x=token_data,
handle=None,
@@ -103,6 +108,11 @@
previous_event=None,
async_finish=self.async_prepare,
allocate_on_comm_stream=False)
dbo_yield_and_switch_from_comm_to_compute()

# record the handle for this ubatch
a2a_idx = dbo_current_ubatch_id()
self.handles[a2a_idx] = handle

return lambda: self._receiver(
event,
@@ -253,7 +263,9 @@ def finalize(
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:

assert self.handle is not None
a2a_idx = dbo_current_ubatch_id()
handle = self.handles[a2a_idx]
assert handle is not None

# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
@@ -267,14 +279,15 @@
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)

dbo_yield_and_switch_from_compute_to_comm()
combined_x, _, event = self.buffer.combine(
x=fused_expert_output,
handle=self.handle,
handle=handle,
topk_weights=None,
config=self._get_combine_config(),
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
dbo_yield_and_switch_from_comm_to_compute()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
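The per-micro-batch handle bookkeeping used above can be reduced to this sketch (class and method names are illustrative; only the two-slot indexing by dbo_current_ubatch_id() mirrors the diff):

    class PerUbatchHandles:
        """One DeepEP dispatch handle per DBO micro-batch, so ubatch 0's
        combine never consumes the handle produced by ubatch 1's dispatch."""

        def __init__(self) -> None:
            self.handles: list = [None, None]

        def record(self, ubatch_id: int, handle) -> None:
            self.handles[ubatch_id] = handle

        def take(self, ubatch_id: int):
            handle = self.handles[ubatch_id]
            assert handle is not None, "dispatch must record a handle before combine"
            return handle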
78 changes: 50 additions & 28 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -13,10 +13,9 @@
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
_resize_cache, count_expert_num_tokens)
from vllm.utils import cdiv
from vllm.v1.worker.ubatching import (dbo_enabled,
dbo_yield,
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook)
dbo_register_recv_hook, dbo_yield)

#
# This file defines a set of base classes used to make MoE kernels more modular.
@@ -501,12 +500,14 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,


class SharedResizableBuffer:

def __init__(self):
self.buffer = None

# NOTE: Assumes the first call to get() is the largest shape,
# this is usually true due to the profile run.
def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype):
def get(self, shape: tuple[int, ...], device: torch.device,
dtype: torch.dtype):
shape_numel = prod(shape)
if self.buffer is None or self.buffer.numel() < shape_numel:
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
@@ -530,9 +531,19 @@ class FusedMoEModularKernel(torch.nn.Module):
layer due to any layer specific state that may be used by the component
objects.
"""
fused_out_buffer = SharedResizableBuffer()
workspace13_buffer = SharedResizableBuffer()
workspace2_buffer = SharedResizableBuffer()
# Class-level buffers, one set per DBO micro-batch (ubatch 0 and 1).
_ubatch_buffers: list[dict[str, SharedResizableBuffer]] = [
{
"fused_out": SharedResizableBuffer(),
"workspace13": SharedResizableBuffer(),
"workspace2": SharedResizableBuffer(),
},
{
"fused_out": SharedResizableBuffer(),
"workspace13": SharedResizableBuffer(),
"workspace2": SharedResizableBuffer(),
},
]

def __init__(
self,
@@ -581,16 +592,18 @@ def _do_fused_experts(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)

# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx = dbo_current_ubatch_id()
buffers = self.__class__._ubatch_buffers[ubatch_idx]

# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = self.workspace13_buffer.get(
workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = self.workspace2_buffer.get(
workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
workspace13 = buffers["workspace13"].get(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = buffers["workspace2"].get(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)

assert fused_out is None or fused_out.shape == fused_out_shape, (
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@@ -682,10 +695,11 @@ def _maybe_chunk_fused_experts(
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
fused_out = self.fused_out_buffer.get(
fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
ubatch_idx = dbo_current_ubatch_id()
buffers = self.__class__._ubatch_buffers[ubatch_idx]
fused_out = buffers["fused_out"].get(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)

def slice_input_tensors(
chunk_idx: int
@@ -852,7 +866,7 @@ def forward(
else:
# Overlap shared expert compute with all2all dispatch.
dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
prepare_ret = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
@@ -868,14 +882,22 @@
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)

dbo_register_recv_hook(hook)
dbo_yield()

if dbo_enabled():
if isinstance(prepare_ret, tuple):
hook, receiver = prepare_ret
else:


Reviewer: How does this differ from the if not self.prepare_finalize.supports_async(): path?


Author: I think for the self.prepare_finalize.prepare() path, the receiver is called first and its result is unpacked into (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights), so we don't need to update it?


Reviewer: I guess I'm curious why this is needed, since I thought HT would go through the:

            if self.shared_experts is not None:
                shared_output = self.shared_experts(a1)

            (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
             _expert_topk_weights) = self.prepare_finalize.prepare(
                 a1,
                 a1_scale,
                 a2_scale,
                 topk_weights,
                 topk_ids,
                 global_num_experts,
                 expert_map,
                 apply_router_weight_on_input,
                 self.fused_experts.quant_config,
             )

path


Author: Do you mean that for HT, supports_async should be False instead of True?

    def supports_async(self) -> bool:
        return True

Since supports_async currently returns True, we take the branch that calls self.prepare_finalize.prepare_async(...).

Then we need to stay compatible with low latency, because it returns:

return (hook, lambda hook: self._receiver(hook, expert_x, expert_num_tokens,
                                      a1_scale, a1.dtype, quant_config))

hook = None

(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver(hook)
receiver = prepare_ret

if hook is not None:
dbo_register_recv_hook(hook)
dbo_yield()
if dbo_enabled():
hook = None
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver(hook)
else:
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver()

# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
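To summarize the control flow debated in the thread above, here is a condensed, runnable sketch of how the forward path can accept both return shapes of prepare_async. It elides the DBO hook registration and yielding and simply runs the hook eagerly; the toy hooks and receivers are placeholders, not vLLM objects:

    def consume_prepare_ret(prepare_ret):
        # The low-latency prepare_async() returns (hook, receiver); the
        # high-throughput path returns only a receiver.
        if isinstance(prepare_ret, tuple):
            hook, receiver = prepare_ret
        else:
            hook, receiver = None, prepare_ret

        if hook is not None:
            hook()                 # drain the async recv before consuming
            return receiver(None)  # the hook already ran in this sketch
        return receiver()

    # Both shapes funnel into the same consumption path:
    print(consume_prepare_ret((lambda: None, lambda hook: "low-latency result")))
    print(consume_prepare_ret(lambda: "high-throughput result"))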