
Commit 8f0a9ca

Authored by mgoin
[Bugfix] Respect modules_to_not_convert within awq_marlin (#9895)
Signed-off-by: mgoin <[email protected]>
1 parent 2094062 commit 8f0a9ca
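
For context, a rough sketch of the config shape this fix handles. The key names ("bits", "group_size", "zero_point", "modules_to_not_convert") come from the diff below; the concrete values and the module name are illustrative only, not taken from the commit:

    # Hypothetical excerpt of a checkpoint's AWQ quantization config
    # (values are examples).
    hf_quant_config = {
        "quant_method": "awq",
        "bits": 4,
        "group_size": 128,
        "zero_point": True,
        # Layers matching these names must stay unquantized; before this
        # fix, awq_marlin ignored the list and quantized them anyway.
        "modules_to_not_convert": ["visual"],
    }

    quant_config = AWQMarlinConfig.from_config(hf_quant_config)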

File tree

1 file changed: +24, -11 lines changed

vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 24 additions & 11 deletions
@@ -9,7 +9,9 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }

-    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
-                 lm_head_quantized: bool) -> None:
+    def __init__(self,
+                 weight_bits: int,
+                 group_size: int,
+                 zero_point: bool,
+                 lm_head_quantized: bool,
+                 modules_to_not_convert: Optional[List[str]] = None) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
-        self.has_zp = has_zp
+        self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []

         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -52,13 +59,14 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool,

         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
-                                has_zp=self.has_zp)
+                                has_zp=self.zero_point)

     def __repr__(self) -> str:
         return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
-                f"has_zp={self.has_zp}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"zero_point={self.zero_point}, "
+                f"lm_head_quantized={self.lm_head_quantized}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")

     @classmethod
     def get_name(cls) -> str:
@@ -80,10 +88,13 @@ def get_config_filenames(cls) -> List[str]:
     def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
-        has_zp = cls.get_from_keys(config, ["zero_point"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, has_zp, lm_head_quantized)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
+                   modules_to_not_convert)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -109,6 +120,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
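
The hunk above delegates the skip decision to is_layer_skipped_awq, imported from awq.py. As a minimal sketch of the expected behavior, assuming a simple substring match against the layer prefix (see awq.py for the actual helper):

    # Hedged sketch of the helper's behavior, not the commit's code:
    # returns True when the layer's prefix matches any entry in
    # modules_to_not_convert, e.g. prefix "visual.blocks.0.mlp" with
    # modules_to_not_convert=["visual"].
    def is_layer_skipped_awq_sketch(prefix, modules_to_not_convert):
        return any(name in prefix for name in modules_to_not_convert)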
@@ -123,7 +136,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
-        has_zp = quant_config.get("zero_point")
+        zero_point = quant_config.get("zero_point")

         if not current_platform.is_cuda():
             return False
@@ -132,15 +145,15 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False

         # If we cannot find the info needed in the config, cannot convert.
-        if (num_bits is None or group_size is None or has_zp is None):
+        if (num_bits is None or group_size is None or zero_point is None):
             return False

         if num_bits not in cls.TYPE_MAP:
             return False

         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp)
+                                      has_zp=zero_point)


 class AWQMarlinLinearMethod(LinearMethodBase):
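
A quick usage sketch tying the pieces together, continuing from the hypothetical hf_quant_config above (values illustrative; the final kernel check requires a supported CUDA GPU):

    # - is_awq_marlin_compatible() still reads "zero_point" from the raw
    #   dict; only the local variable was renamed.
    # - get_quant_method() now returns UnquantizedLinearMethod() for any
    #   linear layer whose prefix matches modules_to_not_convert, and
    #   AWQMarlinLinearMethod(self) otherwise.
    assert AWQMarlinConfig.is_awq_marlin_compatible(hf_quant_config)
    print(quant_config)  # repr now includes modules_to_not_convert=[...]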
