
Commit b2b8ad4

yewentao256 authored and Xu-Wenqing committed
[Feature] Enable DeepGEMM Linear on B200; 1.5% E2E throughput improvement (vllm-project#23351)
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: root <[email protected]>
1 parent 100757f commit b2b8ad4


2 files changed (+13, -16 lines)


vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 6 additions & 16 deletions
@@ -19,8 +19,9 @@
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+                                  should_use_deepgemm_for_fp8_linear)
 
 logger = init_logger(__name__)
 
@@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func(
     return w8a8_block_fp8_matmul
 
 
-def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
-    """
-    Check if DeepGEMM should be used based on the output dtype and weight shape.
-    DeepGEMM is only supported for bfloat16 output dtype and weights with shape
-    divisible by 128.
-    """
-
-    return (current_platform.is_cuda()
-            and current_platform.is_device_capability(90) and has_deep_gemm()
-            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
-            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
-
-
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
     output_dtype = input.dtype
 
-    if should_use_deepgemm(output_dtype, weight):
+    if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
 
         input_2d = input.view(-1, input.shape[-1])
         output_shape = [*input.shape[:-1], weight.shape[0]]
@@ -150,7 +138,9 @@
             column_major_scales=True,
         )
 
+        # ensure DeepGEMM-backed custom op is registered before use
         import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
+
         output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
             q_input,
             weight,
vllm/utils/deep_gemm.py

Lines changed: 7 additions & 0 deletions
@@ -202,6 +202,12 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     return 1 - sim
 
 
+def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
+                                       weight: torch.Tensor):
+    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
+            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+
+
 __all__ = [
     "calc_diff",
     "fp8_gemm_nt",
@@ -210,4 +216,5 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     "per_block_cast_to_fp8",
     "is_blackwell_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
+    "should_use_deepgemm_for_fp8_linear",
 ]
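
To make the new helper's conditions concrete, the short demo below copies the body of should_use_deepgemm_for_fp8_linear from the hunk above and stubs is_deep_gemm_supported() to True (an assumption for illustration only), so the bfloat16 and 128-divisibility checks can be exercised on example weight shapes. It is a sketch, not part of the commit.

import torch

def is_deep_gemm_supported() -> bool:
    # Stub assumed True here; the real check is defined in this module.
    return True

# Body copied from the hunk above.
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
                                       weight: torch.Tensor):
    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)

w_aligned = torch.empty(256, 512)  # both dims divisible by 128
w_ragged = torch.empty(300, 512)   # first dim not divisible by 128

print(should_use_deepgemm_for_fp8_linear(torch.bfloat16, w_aligned))  # True
print(should_use_deepgemm_for_fp8_linear(torch.float16, w_aligned))   # False: not bfloat16
print(should_use_deepgemm_for_fp8_linear(torch.bfloat16, w_ragged))   # False: 300 % 128 != 0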
