Commit 7331c69

[Ultravox] Fix Gemma instantiation, support quantization via --hf-overrides.text_model_id

Signed-off-by: Peter Salas <[email protected]>
1 parent: 2bef2d1

3 files changed (+51, -35 lines)
vllm/config/__init__.py

Lines changed: 8 additions & 4 deletions
@@ -1092,11 +1092,11 @@ def _get_supported_tasks(
 
         assert_never(runner_type)
 
-    def _parse_quant_hf_config(self):
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+    def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
+        quant_cfg = getattr(hf_config, "quantization_config", None)
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
-            quant_cfg = getattr(self.hf_config, "compression_config", None)
+            quant_cfg = getattr(hf_config, "compression_config", None)
 
         else:
             # Set quant_method for ModelOpt models.
@@ -1137,7 +1137,11 @@ def _verify_quantization(self) -> None:
                 self.quantization)
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = self._parse_quant_hf_config()
+        quant_cfg = self._parse_quant_hf_config(self.hf_config)
+        if quant_cfg is None and (text_config := getattr(
+                self.hf_config, "text_config", None)):
+            # Check the text config as well for multi-modal models.
+            quant_cfg = self._parse_quant_hf_config(text_config)
 
         if quant_cfg is not None:
             # Use the community standard 'quant_method'

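The change above makes the quantization lookup reusable: _parse_quant_hf_config now accepts an arbitrary PretrainedConfig, and _verify_quantization falls back to the nested text_config when the top-level config carries no quantization metadata. A minimal standalone sketch of the same fallback pattern, using SimpleNamespace stand-ins rather than vLLM's actual config classes:

from types import SimpleNamespace

def parse_quant_config(hf_config):
    # Prefer the standard "quantization_config" key; compressed-tensors
    # stores its settings under "compression_config" instead.
    quant_cfg = getattr(hf_config, "quantization_config", None)
    if quant_cfg is None:
        quant_cfg = getattr(hf_config, "compression_config", None)
    return quant_cfg

# Multi-modal configs often nest the quantized text backbone one level down.
outer = SimpleNamespace(
    text_config=SimpleNamespace(quantization_config={"quant_method": "awq"}))

quant_cfg = parse_quant_config(outer)
if quant_cfg is None and (text_config := getattr(outer, "text_config", None)):
    quant_cfg = parse_quant_config(text_config)

print(quant_cfg)  # {'quant_method': 'awq'}
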
vllm/model_executor/models/ultravox.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def __init__(self, config: UltravoxConfig):
         else:
             self.act = get_act_fn(config.projector_act)
 
-        dim_out = config.text_hidden_size
+        dim_out = config.text_config.hidden_size
         self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
 
         # Ultravox v0.4.1 and below use layer_norm after the second linear layer

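With text_hidden_size gone from UltravoxConfig, the projector now reads the inner model's hidden size through the new text_config property (defined in the next file), which in turn relies on Transformers' get_text_config(). A hedged sketch of that resolution with stock configs; the model types are chosen for illustration, and the Gemma 3 lines assume a transformers version that ships it:

import transformers

# A plain text model: get_text_config() returns the config itself.
llama_cfg = transformers.CONFIG_MAPPING["llama"]()
assert llama_cfg.get_text_config() is llama_cfg

# A multi-modal wrapper such as Gemma 3: get_text_config() unwraps the
# nested text config, which is where hidden_size actually lives.
gemma_cfg = transformers.CONFIG_MAPPING["gemma3"]()
dim_out = gemma_cfg.get_text_config().hidden_size  # width of linear_2
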
vllm/transformers_utils/configs/ultravox.py

Lines changed: 42 additions & 30 deletions
@@ -20,10 +20,13 @@ class UltravoxConfig(transformers.PretrainedConfig):
 
     Args:
         audio_config (`Union[AutoConfig, dict]`, *optional*):
-            Custom audio config or dict
+            Custom audio config or dict.
         text_config (`Union[AutoConfig, dict]`, *optional*):
-            The config object of the text backbone. Can be any of `LlamaConfig`
-            or `MistralConfig`.
+            The config object of the text backbone.
+        audio_model_id (`str`, *optional*):
+            The model ID of the audio backbone.
+        text_model_id (`str`, *optional*):
+            The model ID of the text backbone.
         ignore_index (`int`, *optional*, defaults to -100):
             The ignore index for the loss function.
         audio_token_index (`int`, *optional*, defaults to 32000):
@@ -60,15 +63,10 @@ def __init__(
         stack_factor: int = 8,
         norm_init: float = 0.4,
         projector_act: str = "swiglu",
-        text_model_lora_config: Optional[dict[str, Any]] = None,
-        audio_model_lora_config: Optional[dict[str, Any]] = None,
         projector_ln_mid: bool = False,
         **kwargs,
     ):
         self.ignore_index = ignore_index
-
-        self.audio_model_id = audio_model_id
-        self.text_model_id = text_model_id
         self.audio_token_index = audio_token_index
 
         self.hidden_size = hidden_size
@@ -78,35 +76,49 @@ def __init__(
         self.projector_ln_mid = projector_ln_mid
 
         if text_model_id is not None:
-            # Avoid circular import
-            from vllm.transformers_utils.config import get_config
-
-            text_config_obj = get_config(text_model_id,
-                                         trust_remote_code=False)
+            # N.B. Sets the wrapped_model_config below.
+            self.text_model_id = text_model_id
         else:
+            self.text_model_id = None
             text_config = text_config or {}
-            text_config_obj = transformers.CONFIG_MAPPING[text_config.get(
-                "model_type", "llama")](**text_config)
-
-        inner_text_config = text_config_obj.get_text_config()
+            self.wrapped_model_config = transformers.CONFIG_MAPPING[
+                text_config.get("model_type", "llama")](**text_config)
 
         if audio_model_id is not None:
-            # Avoid circular import
-            from vllm.transformers_utils.config import get_config
-
-            audio_config = get_config(audio_model_id, trust_remote_code=False)
+            # N.B. Sets the audio_config below.
+            self.audio_model_id = audio_model_id
         else:
+            self.audio_model_id = None
             audio_config = audio_config or {}
-            audio_config = transformers.CONFIG_MAPPING[audio_config.get(
+            self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
                 "model_type", "whisper")](**audio_config)
 
-        self.text_config = text_config_obj
-        self.audio_config = audio_config
-        self.text_model_lora_config = text_model_lora_config or {}
-        self.audio_model_lora_config = audio_model_lora_config or {}
+        super().__init__(**kwargs)
 
-        self.vocab_size = inner_text_config.vocab_size
-        self.initializer_range = inner_text_config.initializer_range
-        self.text_hidden_size = inner_text_config.hidden_size
+    def __setattr__(self, key, value):
+        # Since --hf-overrides are applied _after_ the UltravoxConfig is
+        # instantiated, load the configs implicitly when assigning text_model_id
+        # or audio_model_id. This allows:
+        #
+        # --hf-overrides.text_model_id=<quantized variant>
+        #
+        # to behave as intended.
+        if key == "text_model_id" and value is not None:
+            from vllm.transformers_utils.config import get_config
 
-        super().__init__(**kwargs)
+            self.wrapped_model_config = get_config(value,
+                                                   trust_remote_code=False)
+        elif key == "audio_model_id" and value is not None:
+            from vllm.transformers_utils.config import get_config
+
+            self.audio_config = get_config(value, trust_remote_code=False)
+
+        return super().__setattr__(key, value)
+
+    @property
+    def text_config(self) -> Optional[transformers.PretrainedConfig]:
+        # When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate
+        # the full model, but the text config is the text config of the inner
+        # model.
+        return (self.wrapped_model_config.get_text_config()
+                if self.wrapped_model_config else None)

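The __setattr__ hook is the load-bearing piece of this commit: --hf-overrides mutates attributes on an already-constructed config, so a plain attribute assignment has to trigger the reload of the wrapped config. A minimal sketch of the pattern outside vLLM; load_config is a hypothetical stand-in for vllm.transformers_utils.config.get_config:

def load_config(model_id: str) -> dict:
    # Hypothetical stand-in for get_config(); records the ID rather than
    # fetching anything from the Hub.
    return {"loaded_from": model_id}

class WrapperConfig:
    # Class-level default so reads succeed before any ID is assigned.
    wrapped_model_config = None

    def __init__(self, text_model_id=None):
        self.text_model_id = text_model_id  # routed through __setattr__

    def __setattr__(self, key, value):
        # Overrides applied after construction are plain assignments, so the
        # derived config must be refreshed here rather than only in __init__.
        if key == "text_model_id" and value is not None:
            self.wrapped_model_config = load_config(value)
        return super().__setattr__(key, value)

cfg = WrapperConfig()
# An override such as --hf-overrides.text_model_id=<quantized variant>
# reduces to an attribute assignment, which still loads the wrapped config:
cfg.text_model_id = "some-org/llama-awq"  # hypothetical model ID
print(cfg.wrapped_model_config)  # {'loaded_from': 'some-org/llama-awq'}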