@@ -20,10 +20,13 @@ class UltravoxConfig(transformers.PretrainedConfig):
20
20
21
21
Args:
22
22
audio_config (`Union[AutoConfig, dict]`, *optional*):
23
- Custom audio config or dict
23
+ Custom audio config or dict.
24
24
text_config (`Union[AutoConfig, dict]`, *optional*):
25
- The config object of the text backbone. Can be any of `LlamaConfig`
26
- or `MistralConfig`.
25
+ The config object of the text backbone.
26
+ audio_model_id (`str`, *optional*):
27
+ The model ID of the audio backbone.
28
+ text_model_id (`str`, *optional*):
29
+ The model ID of the text backbone.
27
30
ignore_index (`int`, *optional*, defaults to -100):
28
31
The ignore index for the loss function.
29
32
audio_token_index (`int`, *optional*, defaults to 32000):
@@ -60,15 +63,10 @@ def __init__(
60
63
stack_factor : int = 8 ,
61
64
norm_init : float = 0.4 ,
62
65
projector_act : str = "swiglu" ,
63
- text_model_lora_config : Optional [dict [str , Any ]] = None ,
64
- audio_model_lora_config : Optional [dict [str , Any ]] = None ,
65
66
projector_ln_mid : bool = False ,
66
67
** kwargs ,
67
68
):
68
69
self .ignore_index = ignore_index
69
-
70
- self .audio_model_id = audio_model_id
71
- self .text_model_id = text_model_id
72
70
self .audio_token_index = audio_token_index
73
71
74
72
self .hidden_size = hidden_size
@@ -78,35 +76,49 @@ def __init__(
78
76
self .projector_ln_mid = projector_ln_mid
79
77
80
78
if text_model_id is not None :
81
- # Avoid circular import
82
- from vllm .transformers_utils .config import get_config
83
-
84
- text_config_obj = get_config (text_model_id ,
85
- trust_remote_code = False )
79
+ # N.B. Sets the wrapped_model_config below.
80
+ self .text_model_id = text_model_id
86
81
else :
82
+ self .text_model_id = None
87
83
text_config = text_config or {}
88
- text_config_obj = transformers .CONFIG_MAPPING [text_config .get (
89
- "model_type" , "llama" )](** text_config )
90
-
91
- inner_text_config = text_config_obj .get_text_config ()
84
+ self .wrapped_model_config = transformers .CONFIG_MAPPING [
85
+ text_config .get ("model_type" , "llama" )](** text_config )
92
86
93
87
if audio_model_id is not None :
94
- # Avoid circular import
95
- from vllm .transformers_utils .config import get_config
96
-
97
- audio_config = get_config (audio_model_id , trust_remote_code = False )
88
+ # N.B. Sets the audio_config below.
89
+ self .audio_model_id = audio_model_id
98
90
else :
91
+ self .audio_model_id = None
99
92
audio_config = audio_config or {}
100
- audio_config = transformers .CONFIG_MAPPING [audio_config .get (
93
+ self . audio_config = transformers .CONFIG_MAPPING [audio_config .get (
101
94
"model_type" , "whisper" )](** audio_config )
102
95
103
- self .text_config = text_config_obj
104
- self .audio_config = audio_config
105
- self .text_model_lora_config = text_model_lora_config or {}
106
- self .audio_model_lora_config = audio_model_lora_config or {}
96
+ super ().__init__ (** kwargs )
107
97
108
- self .vocab_size = inner_text_config .vocab_size
109
- self .initializer_range = inner_text_config .initializer_range
110
- self .text_hidden_size = inner_text_config .hidden_size
98
+ def __setattr__ (self , key , value ):
99
+ # Since --hf-overrides are applied _after_ the UltravoxConfig is
100
+ # instantiated, load the configs implicitly when assigning text_model_id
101
+ # or audio_model_id. This allows:
102
+ #
103
+ # --hf-overrides.text_model_id=<quantized variant>
104
+ #
105
+ # to behave as intended.
106
+ if key == "text_model_id" and value is not None :
107
+ from vllm .transformers_utils .config import get_config
111
108
112
- super ().__init__ (** kwargs )
109
+ self .wrapped_model_config = get_config (value ,
110
+ trust_remote_code = False )
111
+ elif key == "audio_model_id" and value is not None :
112
+ from vllm .transformers_utils .config import get_config
113
+
114
+ self .audio_config = get_config (value , trust_remote_code = False )
115
+
116
+ return super ().__setattr__ (key , value )
117
+
118
+ @property
119
+ def text_config (self ) -> Optional [transformers .PretrainedConfig ]:
120
+ # When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate
121
+ # the full model, but the text config is the text config of the inner
122
+ # model.
123
+ return (self .wrapped_model_config .get_text_config ()
124
+ if self .wrapped_model_config else None )
0 commit comments