Commit 3724107

[Misc] Removed force_fp8_e4m3fnuz from FP8LinearOp (#23725)
Signed-off-by: Julien Lin <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
1 parent c9f7081 commit 3724107

5 files changed: +45 -30 lines changed
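For orientation, a minimal before/after sketch of the constructor call this commit changes (illustrative only; the argument values are placeholders taken from the diffs below):

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

# Before this commit: callers could force the fp8 e4m3fnuz (non-CUTLASS) code path.
fp8_linear = Fp8LinearOp(
    act_quant_static=False,
    act_quant_group_shape=GroupShape.PER_TOKEN,
    force_fp8_e4m3fnuz=True,  # keyword removed by this commit
)

# After this commit: the backend is inferred from the platform and
# cutlass_fp8_supported(); tests patch that helper instead (see tests/utils.py).
fp8_linear = Fp8LinearOp(
    act_quant_static=False,
    act_quant_group_shape=GroupShape.PER_TOKEN,
)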

tests/compile/test_fusion.py

Lines changed: 16 additions & 11 deletions

@@ -15,9 +15,10 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, maybe_create_device_identity)
+    Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
 from vllm.platforms import current_platform

+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend

 FP8_DTYPE = current_platform.fp8_dtype()
@@ -26,9 +27,9 @@
 class TestModel(torch.nn.Module):

     def __init__(self, hidden_size: int, eps: float, static: bool,
-                 force_fp8_e4m3fnuz: bool, *args, **kwargs):
+                 cuda_force_torch: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
+        self.cuda_force_torch = cuda_force_torch
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
@@ -42,11 +43,12 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]
-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=static,
-            act_quant_group_shape=group_shape,
-        )
+
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=static,
+                act_quant_group_shape=group_shape,
+            )

     def forward(self, x):
         resid = torch.sqrt(x)
@@ -81,11 +83,14 @@ def ops_in_model_after(self):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
+# cuda_force_torch used to test torch code path on platforms that
+# cutlass_fp8_supported() == True.
+@pytest.mark.parametrize("cuda_force_torch",
+                         [True, False] if cutlass_fp8_supported() else [True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              force_fp8_e4m3fnuz):
+                              cuda_force_torch):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
@@ -102,7 +107,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     fusion_pass = FusionPass.instance(vllm_config)

     backend = TestBackend(noop_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
+    model = TestModel(hidden_size, eps, static, cuda_force_torch)

     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size)
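The test no longer passes a flag into Fp8LinearOp; it patches cutlass_fp8_supported() while the op is built, which is enough because the hunk in w8a8_utils.py below shows the preferred backend being chosen in __init__. A hypothetical standalone sketch of the same pattern (assuming the helper is importable as tests.utils and vLLM is installed):

from tests.utils import override_cutlass_fp8_supported
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

# Force the torch fallback on a CUDA machine by pretending CUTLASS FP8 is
# unavailable for the duration of construction.
with override_cutlass_fp8_supported(False):
    fp8_linear = Fp8LinearOp(
        act_quant_static=True,
        act_quant_group_shape=GroupShape.PER_TENSOR,
    )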

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 15 additions & 12 deletions

@@ -17,9 +17,10 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp)
+    Fp8LinearOp, cutlass_fp8_supported)
 from vllm.platforms import current_platform

+from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend

 FP8_DTYPE = current_platform.fp8_dtype()
@@ -32,19 +33,19 @@ def is_nvfp4_supported():

 class TestSiluMulFp8QuantModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
+    def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
         super().__init__()
         self.silu_and_mul = SiluAndMul()
         self.wscale = torch.rand(1, dtype=torch.float32)
         self.scale = torch.rand(1, dtype=torch.float32)

         self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()

-        self.fp8_linear = Fp8LinearOp(
-            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
-            act_quant_static=True,
-            act_quant_group_shape=GroupShape.PER_TENSOR,
-        )
+        with override_cutlass_fp8_supported(not cuda_force_torch):
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )

     def forward(self, x):
         y = self.silu_and_mul(x)
@@ -96,12 +97,15 @@ def ops_in_model_after(self):
 @pytest.mark.parametrize(
     "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
     if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
-@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
+# cuda_force_torch used to test torch code path on platforms that
+# cutlass_fp8_supported() == True.
+@pytest.mark.parametrize("cuda_force_torch",
+                         [True, False] if cutlass_fp8_supported() else [True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
-                                   force_fp8_e4m3fnuz):
-    if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
+                                   cuda_force_torch):
+    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
         pytest.skip("Duplicate tests for NVFP4")

     torch.set_default_device("cuda")
@@ -114,8 +118,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
     fusion_pass = ActivationQuantFusionPass(config)

     backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
-    model = model_class(hidden_size=hidden_size,
-                        force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)
+    model = model_class(hidden_size, cuda_force_torch)

     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size * 2)
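One detail worth calling out: the parametrize list collapses to [True] when cutlass_fp8_supported() is false, so on hardware without CUTLASS FP8 only the forced-torch case is generated and nothing has to be skipped at runtime. A generic, self-contained sketch of that idiom (feature_available is a hypothetical stand-in for cutlass_fp8_supported):

import pytest


def feature_available() -> bool:  # stand-in for a hardware capability probe
    return False


# Exercise both code paths only where the fast path can actually run;
# otherwise fall back to the single case that works everywhere.
@pytest.mark.parametrize("force_fallback",
                         [True, False] if feature_available() else [True])
def test_both_paths(force_fallback):
    assert isinstance(force_fallback, bool)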

tests/utils.py

Lines changed: 9 additions & 0 deletions

@@ -17,6 +17,7 @@
 from multiprocessing import Process
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from unittest.mock import patch

 import cloudpickle
 import httpx
@@ -1077,3 +1078,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         return attn_backend_list
     else:
         raise ValueError("Unsupported platform")
+
+
+@contextmanager
+def override_cutlass_fp8_supported(value: bool):
+    with patch(
+            "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
+            return_value=value):
+        yield
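The helper is a thin wrapper over unittest.mock.patch: for the duration of the with block, the cutlass_fp8_supported name inside w8a8_utils is replaced by a mock that returns the given value, and the original is restored on exit. A stdlib-only sketch of the same mechanism (math.sqrt is just a convenient stand-in target):

import math
from contextlib import contextmanager
from unittest.mock import patch


@contextmanager
def override_sqrt(value: float):
    # patch() swaps math.sqrt for a MagicMock(return_value=value) and
    # restores the original when the block exits.
    with patch("math.sqrt", return_value=value):
        yield


with override_sqrt(42.0):
    assert math.sqrt(9) == 42.0  # patched
assert math.sqrt(9) == 3.0  # restored

As with the vLLM helper, the patch target is the name the code under test actually looks up (here math.sqrt; above, the cutlass_fp8_supported symbol inside w8a8_utils).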

vllm/model_executor/layers/quantization/ptpc_fp8.py

Lines changed: 3 additions & 3 deletions

@@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
     """

     def __init__(self, quant_config: PTPCFp8Config):
+        assert current_platform.is_rocm(), \
+            "PTPCFp8LinearMethod is only supported on ROCm."
         super().__init__(quant_config=quant_config)
         # Force weight quantization
         self.quant_config.is_checkpoint_fp8_serialized = False
         self.fp8_linear = Fp8LinearOp(
-            act_quant_static=False,
-            act_quant_group_shape=GroupShape.PER_TOKEN,
-            force_fp8_e4m3fnuz=True)
+            act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
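Note: Fp8LinearOp selects its ROCm backend from current_platform.is_rocm() alone (see the w8a8_utils.py hunk below), so the flag was effectively redundant here; the new assert makes the ROCm-only assumption of PTPCFp8LinearMethod explicit instead of encoding it via force_fp8_e4m3fnuz=True.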

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 2 additions & 4 deletions

@@ -355,12 +355,10 @@ class Fp8LinearOp:
     def __init__(self,
                  act_quant_static: bool,
                  act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
-                 pad_output: Optional[bool] = None,
-                 force_fp8_e4m3fnuz: bool = False):
+                 pad_output: Optional[bool] = None):
         if current_platform.is_rocm():
             self.preferred_backend = "rocm"
-        elif current_platform.is_cuda(
-        ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
+        elif current_platform.is_cuda() and cutlass_fp8_supported():
             if has_flashinfer() and current_platform.has_device_capability(
                     100):
                 self.preferred_backend = "flashinfer"
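Taken together with the ROCm branch that stays in place, the backend choice in Fp8LinearOp.__init__ now depends only on the platform and on cutlass_fp8_supported(). A simplified, standalone sketch of that ordering (the "cutlass" and final "torch" fallback branches are assumptions based on the surrounding code, not shown in this hunk):

def pick_preferred_backend(is_rocm: bool, is_cuda: bool, cutlass_ok: bool,
                           flashinfer_ok: bool) -> str:
    # Mirrors the hunk above: ROCm first, then CUDA with CUTLASS FP8 support,
    # preferring FlashInfer on devices that have it (capability 100).
    if is_rocm:
        return "rocm"
    if is_cuda and cutlass_ok:
        return "flashinfer" if flashinfer_ok else "cutlass"  # "cutlass" assumed
    return "torch"  # assumed fallback when no accelerated path is available


# With cutlass_fp8_supported() patched to False, construction lands on the
# fallback path, which is exactly what cuda_force_torch exercises in the tests.
assert pick_preferred_backend(is_rocm=False, is_cuda=True, cutlass_ok=False,
                              flashinfer_ok=False) == "torch"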
