@@ -9,7 +9,9 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }
 
-    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
-                 lm_head_quantized: bool) -> None:
+    def __init__(self,
+                 weight_bits: int,
+                 group_size: int,
+                 zero_point: bool,
+                 lm_head_quantized: bool,
+                 modules_to_not_convert: Optional[List[str]] = None) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
-        self.has_zp = has_zp
+        self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []
 
         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -52,13 +59,14 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
 
         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
-                                has_zp=self.has_zp)
+                                has_zp=self.zero_point)
 
     def __repr__(self) -> str:
         return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
-                f"has_zp={self.has_zp}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"zero_point={self.zero_point}, "
+                f"lm_head_quantized={self.lm_head_quantized}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -80,10 +88,13 @@ def get_config_filenames(cls) -> List[str]:
     def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
-        has_zp = cls.get_from_keys(config, ["zero_point"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, has_zp, lm_head_quantized)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
+                   modules_to_not_convert)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
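
For context, a checkpoint whose quantization_config carries the new key would now round-trip through from_config as sketched below. The keys match the ones from_config reads above; the concrete values, and the "visual" entry, are illustrative rather than taken from this PR.

# Illustrative AWQ quantization_config; values are made up for the example.
awq_cfg = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    # Hypothetical entry: keep a vision tower unquantized.
    "modules_to_not_convert": ["visual"],
}

config = AWQMarlinConfig.from_config(awq_cfg)
assert config.zero_point is True
assert config.modules_to_not_convert == ["visual"]
# "lm_head" is absent, so lm_head_quantized falls back to False.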
@@ -109,6 +120,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
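
get_quant_method now consults is_layer_skipped_awq, imported above from the AWQ config module, so skipped layers fall back to an unquantized linear method. A minimal sketch of the check this code path assumes, a substring match of the layer prefix against the skip list:

from typing import List

def is_layer_skipped_awq(prefix: str,
                         modules_to_not_convert: List[str]) -> bool:
    # A layer is skipped when any skip-list entry occurs in its dotted
    # prefix, e.g. "visual" matches "visual.blocks.0.attn.qkv".
    return any(module_name in prefix
               for module_name in modules_to_not_convert)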
@@ -123,7 +136,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
-        has_zp = quant_config.get("zero_point")
+        zero_point = quant_config.get("zero_point")
 
         if not current_platform.is_cuda():
             return False
@@ -132,15 +145,15 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False
 
         # If we cannot find the info needed in the config, cannot convert.
-        if (num_bits is None or group_size is None or zero_point is None):
+        if (num_bits is None or group_size is None or zero_point is None):
             return False
 
         if num_bits not in cls.TYPE_MAP:
             return False
 
         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp)
+                                      has_zp=zero_point)
 
 
 class AWQMarlinLinearMethod(LinearMethodBase):