12
12
from vllm .config import VllmConfig
13
13
from vllm .model_executor .layers .linear import (ColumnParallelLinear ,
14
14
RowParallelLinear )
15
- from vllm .model_executor .layers .pooler import Pooler , PoolingType
15
+ from vllm .model_executor .layers .pooler import Pooler , PoolingType , SimplePooler
16
16
from vllm .model_executor .pooling_metadata import PoolingMetadata
17
17
from vllm .sequence import IntermediateTensors , PoolerOutput
18
18
@@ -32,7 +32,7 @@ def forward(self, input):
32
32
return self .activation (input )
33
33
34
34
35
- class Qwen2ForRewardModel (nn .Module , SupportsLoRA , SupportsPP ):
35
+ class Qwen2RewardBaseModel (nn .Module , SupportsLoRA , SupportsPP ):
36
36
packed_modules_mapping = {
37
37
"qkv_proj" : [
38
38
"q_proj" ,
@@ -60,7 +60,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
60
60
config = vllm_config .model_config .hf_config
61
61
quant_config = vllm_config .quant_config
62
62
lora_config = vllm_config .lora_config
63
- pooler_config = vllm_config .model_config .pooler_config
64
63
65
64
self .config = config
66
65
self .lora_config = lora_config
@@ -74,14 +73,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
74
73
config .hidden_size ,
75
74
quant_config = quant_config ),
76
75
ReLU (),
77
- RowParallelLinear (config .hidden_size , 1 ,
76
+ RowParallelLinear (config .hidden_size ,
77
+ config .num_labels ,
78
78
quant_config = quant_config ),
79
79
)
80
- self ._pooler = Pooler .from_config_with_defaults (
81
- pooler_config ,
82
- pooling_type = PoolingType .ALL ,
83
- normalize = False ,
84
- softmax = False )
80
+ self ._pooler : SimplePooler
85
81
self .make_empty_intermediate_tensors = (
86
82
self .model .make_empty_intermediate_tensors )
87
83
@@ -115,3 +111,31 @@ def load_weights(self, weights: Iterable[Tuple[str,
115
111
loader = AutoWeightsLoader (self ,
116
112
ignore_unexpected_prefixes = ["lm_head." ])
117
113
return loader .load_weights (weights )
114
+
115
+
116
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
    """Qwen2 reward model configured with a single output label.

    Forces ``num_labels`` to 1 on the HF config and attaches a pooler that
    uses ``PoolingType.ALL`` with neither normalization nor softmax.
    """

    def __init__(self, *, vllm_config, prefix=""):
        """Build the base model, then install the ALL-type pooler.

        Args:
            vllm_config: Engine configuration; its
                ``model_config.hf_config.num_labels`` is overwritten in
                place with 1.
            prefix: Forwarded unchanged to the base-class constructor.
        """
        # Assigned before the base constructor runs, since the base class
        # sizes its final projection from ``config.num_labels``.
        # NOTE: this writes to the hf_config object owned by vllm_config.
        hf_config = vllm_config.model_config.hf_config
        hf_config.num_labels = 1
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # The keyword values are passed alongside the (possibly unset)
        # pooler_config; how the two are merged is decided by
        # ``Pooler.from_config_with_defaults``.
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )
127
+
128
+
129
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
    """Qwen2 process reward model configured with two output labels.

    Forces ``num_labels`` to 2 on the HF config and attaches a pooler that
    uses ``PoolingType.STEP`` with softmax enabled and no normalization.
    """

    def __init__(self, *, vllm_config, prefix=""):
        """Build the base model, then install the STEP-type pooler.

        Args:
            vllm_config: Engine configuration; its
                ``model_config.hf_config.num_labels`` is overwritten in
                place with 2.
            prefix: Forwarded unchanged to the base-class constructor.
        """
        # Assigned before the base constructor runs, since the base class
        # sizes its final projection from ``config.num_labels``.
        # NOTE: this writes to the hf_config object owned by vllm_config.
        hf_config = vllm_config.model_config.hf_config
        hf_config.num_labels = 2
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # NOTE(review): 151651 is presumably the tokenizer id of the
        # step-separator token for the target checkpoint -- confirm against
        # the model's tokenizer before reusing with other vocabularies.
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.STEP,
            normalize=False,
            softmax=True,
            step_tag_id=151651,
        )
0 commit comments