-
Notifications
You must be signed in to change notification settings - Fork 7
DBO HT without cudagraph #113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: sage/dbo-full-cudagraphs
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,10 +13,9 @@ | |
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable | ||
_resize_cache, count_expert_num_tokens) | ||
from vllm.utils import cdiv | ||
from vllm.v1.worker.ubatching import (dbo_enabled, | ||
dbo_yield, | ||
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, | ||
dbo_maybe_run_recv_hook, | ||
dbo_register_recv_hook) | ||
dbo_register_recv_hook, dbo_yield) | ||
|
||
# | ||
# This file defines a set of base classes used to make MoE kernels more modular. | ||
|
@@ -501,12 +500,14 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, | |
|
||
|
||
class SharedResizableBuffer: | ||
|
||
def __init__(self): | ||
self.buffer = None | ||
|
||
# NOTE: Assumes the first call to get() is the largest shape, | ||
# this is usually true due to the profile run. | ||
def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype): | ||
def get(self, shape: tuple[int, ...], device: torch.device, | ||
dtype: torch.dtype): | ||
shape_numel = prod(shape) | ||
if self.buffer is None or self.buffer.numel() < shape_numel: | ||
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) | ||
|
@@ -530,9 +531,19 @@ class FusedMoEModularKernel(torch.nn.Module): | |
layer due to any layer specific state that may be used by the component | ||
objects. | ||
""" | ||
fused_out_buffer = SharedResizableBuffer() | ||
workspace13_buffer = SharedResizableBuffer() | ||
workspace2_buffer = SharedResizableBuffer() | ||
# class-level ubatch buffers (0/1) | ||
_ubatch_buffers: list[dict[str, SharedResizableBuffer]] = [ | ||
{ | ||
"fused_out": SharedResizableBuffer(), | ||
"workspace13": SharedResizableBuffer(), | ||
"workspace2": SharedResizableBuffer(), | ||
}, | ||
{ | ||
"fused_out": SharedResizableBuffer(), | ||
"workspace13": SharedResizableBuffer(), | ||
"workspace2": SharedResizableBuffer(), | ||
}, | ||
] | ||
|
||
def __init__( | ||
self, | ||
|
@@ -581,16 +592,18 @@ def _do_fused_experts( | |
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, | ||
expert_tokens_meta) | ||
|
||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO | ||
ubatch_idx = dbo_current_ubatch_id() | ||
buffers = self.__class__._ubatch_buffers[ubatch_idx] | ||
|
||
# We can reuse the memory between cache1 and cache3 because by the | ||
# time we need cache3, we're done with cache1. | ||
workspace13 = self.workspace13_buffer.get( | ||
workspace13_shape, | ||
device=a1.device, | ||
dtype=workspace_dtype) | ||
workspace2 = self.workspace2_buffer.get( | ||
workspace2_shape, | ||
device=a1.device, | ||
dtype=workspace_dtype) | ||
workspace13 = buffers["workspace13"].get(workspace13_shape, | ||
device=a1.device, | ||
dtype=workspace_dtype) | ||
workspace2 = buffers["workspace2"].get(workspace2_shape, | ||
device=a1.device, | ||
dtype=workspace_dtype) | ||
|
||
assert fused_out is None or fused_out.shape == fused_out_shape, ( | ||
f"fused_out {fused_out.shape} but expected {fused_out_shape}") | ||
|
@@ -682,10 +695,11 @@ def _maybe_chunk_fused_experts( | |
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( | ||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, | ||
expert_tokens_meta) | ||
fused_out = self.fused_out_buffer.get( | ||
fused_out_shape, | ||
device=a1q.device, | ||
dtype=a1.dtype) | ||
ubatch_idx = dbo_current_ubatch_id() | ||
buffers = self.__class__._ubatch_buffers[ubatch_idx] | ||
fused_out = buffers["fused_out"].get(fused_out_shape, | ||
device=a1q.device, | ||
dtype=a1.dtype) | ||
|
||
def slice_input_tensors( | ||
chunk_idx: int | ||
|
@@ -852,7 +866,7 @@ def forward( | |
else: | ||
# Overlap shared expert compute with all2all dispatch. | ||
dbo_maybe_run_recv_hook() | ||
hook, receiver = self.prepare_finalize.prepare_async( | ||
prepare_ret = self.prepare_finalize.prepare_async( | ||
a1, | ||
a1_scale, | ||
a2_scale, | ||
|
@@ -868,14 +882,22 @@ def forward( | |
if self.shared_experts is not None: | ||
shared_output = self.shared_experts(a1) | ||
|
||
dbo_register_recv_hook(hook) | ||
dbo_yield() | ||
|
||
if dbo_enabled(): | ||
if isinstance(prepare_ret, tuple): | ||
hook, receiver = prepare_ret | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. how does this differ from the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think for the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess I'm curious why this is needed since I thought HT would go through the:
path There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you mean for HT, def supports_async(self) -> bool:
return True
then we need to be compatible with low latency, because it is returning return (hook, lambda hook: self._receiver(hook, expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config)) |
||
hook = None | ||
|
||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, | ||
_expert_topk_weights) = receiver(hook) | ||
receiver = prepare_ret | ||
|
||
if hook is not None: | ||
dbo_register_recv_hook(hook) | ||
dbo_yield() | ||
if dbo_enabled(): | ||
hook = None | ||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, | ||
_expert_topk_weights) = receiver(hook) | ||
else: | ||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, | ||
_expert_topk_weights) = receiver() | ||
|
||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks. | ||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we restrict this to just when the batch is actually running DBO? or will this do that already?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we do it already?
dbo_enabled()
is one of the conditions