Skip to content
Merged
6 changes: 3 additions & 3 deletions tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
CompressedTensorsWNA16, cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform
Expand Down Expand Up @@ -668,8 +668,8 @@ def check_model(model):
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
if isinstance(qkv_proj.scheme, scheme) or isinstance(
qkv_proj.scheme, CompressedTensorsW4A16Fp4
) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
qkv_proj.scheme,
CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported():
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,7 @@ def weight_loader(self,
param.materialize(final_shape, dtype=loaded_weight.dtype)

expert_data = param.data if full_load else param.data[expert_id]

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
Expand Down Expand Up @@ -1277,6 +1278,7 @@ def weight_loader(self,
tp_rank=self.tp_rank)
return True if return_success else None

# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if "ModelOpt" in quant_method_name:
if ('weight_scale_2' in weight_name
or 'input_scale' in weight_name):
Expand All @@ -1293,7 +1295,7 @@ def weight_loader(self,
tp_rank=self.tp_rank)
return True if return_success else None

# Case weight scales, zero_points and offset
# Case weight scales, zero_points and offset, weight/input global scales
if ("scale" in weight_name or "zero" in weight_name
or "offset" in weight_name):
# load the weight scales and zp based on the quantization scheme
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
cutlass_fp4_supported)
from vllm.platforms import current_platform

logger = init_logger(__name__)
Expand Down Expand Up @@ -375,7 +377,7 @@ def _get_scheme_from_parts(

if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
if CompressedTensorsW4A4Fp4.cutlass_fp4_supported(
if cutlass_fp4_supported(
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return CompressedTensorsW4A4Fp4()
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
Expand All @@ -48,12 +52,11 @@ class GPTQMarlinState(Enum):


__all__ = [
"CompressedTensorsMoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4MoeMethod"
]


Expand Down Expand Up @@ -86,6 +89,8 @@ def get_moe_method(
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod()
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
Expand All @@ -97,6 +102,268 @@ def get_moe_method(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")


class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):

def __init__(self):
self.use_marlin = not cutlass_fp4_supported()
self.group_size = 16

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):

layer.num_experts = num_experts
layer.params_dtype = params_dtype

w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
requires_grad=False,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=torch.uint8),
requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

# Weight Scales
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.group_size,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)

w2_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# Weight Global Scales
w13_weight_scale_2 = torch.nn.Parameter(torch.empty(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)

w2_weight_scale_2 = torch.nn.Parameter(torch.empty(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)

# Input Global Scales
w13_input_scale = torch.nn.Parameter(torch.empty(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_global_scale", w13_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_input_scale, extra_weight_attrs)

w2_input_scale = torch.nn.Parameter(torch.empty(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_global_scale", w2_input_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w2_input_scale, extra_weight_attrs)

def swizzle_blockscale(self, scale: torch.tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

# From packed to weight
layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data,
requires_grad=False)

layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
requires_grad=False)

if not torch.allclose(layer.w13_weight_global_scale[:, 0],
layer.w13_weight_global_scale[:, 1]):
logger.warning_once(
"w1_weight_global_scale must match w3_weight_global_scale. "
"Accuracy may be affected.")

# Take inverse of global scale saved to disk
layer.w13_weight_scale_2 = torch.nn.Parameter(
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False)

layer.w2_weight_scale_2 = torch.nn.Parameter(
1 / layer.w2_weight_global_scale.data, requires_grad=False)

if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
return

# swizzle weight scales
layer.w13_blockscale_swizzled = torch.nn.Parameter(
self.swizzle_blockscale(layer.w13_weight_scale),
requires_grad=False)

layer.w2_blockscale_swizzled = torch.nn.Parameter(
self.swizzle_blockscale(layer.w2_weight_scale),
requires_grad=False)

# w13
w13_input_global_scale = layer.w13_input_global_scale.max(
dim=1).values.to(torch.float32)

layer.g1_alphas = torch.nn.Parameter(
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
requires_grad=False)

layer.w13_input_scale_quant = torch.nn.Parameter(
(w13_input_global_scale), requires_grad=False)

# w2
layer.g2_alphas = torch.nn.Parameter(
((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
torch.float32),
requires_grad=False)

layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.")

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)

if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=layer.w13_weight_scale_2,
global_scale2=layer.w2_weight_scale_2,
quant_type_id=scalar_types.float4_e2m1f.id,
global_num_experts=global_num_experts,
expert_map=expert_map)

assert activation == "silu", "Only SiLU activation is supported."
assert not apply_router_weight_on_input, (
"Router weight on input is not "
"supported for CompressedTensorsW4A4MoeMethod.")
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod.")

from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)

# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device).to(x.dtype)


class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

def __init__(
Expand Down
Loading