     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+                                  should_use_deepgemm_for_fp8_linear)

 logger = init_logger(__name__)
@@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func(
     return w8a8_block_fp8_matmul


-def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
-    """
-    Check if DeepGEMM should be used based on the output dtype and weight shape.
-    DeepGEMM is only supported for bfloat16 output dtype and weights with shape
-    divisible by 128.
-    """
-
-    return (current_platform.is_cuda()
-            and current_platform.is_device_capability(90) and has_deep_gemm()
-            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
-            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
-
-
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
     output_dtype = input.dtype

-    if should_use_deepgemm(output_dtype, weight):
+    if should_use_deepgemm_for_fp8_linear(output_dtype, weight):

         input_2d = input.view(-1, input.shape[-1])
         output_shape = [*input.shape[:-1], weight.shape[0]]
@@ -150,7 +138,9 @@ def apply_w8a8_block_fp8_linear(
             column_major_scales=True,
         )

+        # ensure DeepGEMM-backed custom op is registered before use
         import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
+
         output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
             q_input,
             weight,
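
For reference, a minimal sketch of what the shared helper is assumed to check, reconstructed from the inline should_use_deepgemm() deleted above; the canonical implementation lives in vllm.utils.deep_gemm and may differ in detail:

import torch

from vllm import envs
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm


def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
                                       weight: torch.Tensor) -> bool:
    # Sketch only, mirroring the removed inline helper: use DeepGEMM when
    # running on CUDA Hopper (SM90) with DeepGEMM installed and enabled via
    # VLLM_USE_DEEP_GEMM, the output dtype is bfloat16, and both weight
    # dimensions are divisible by the 128x128 quantization block.
    return (current_platform.is_cuda()
            and current_platform.is_device_capability(90) and has_deep_gemm()
            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)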