@@ -13,14 +13,16 @@

 from .audio import resample_audio
 from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
-                     ImageItem, ModalityData, MultiModalDataDict,
-                     NestedTensors, VideoItem)
+                     ImageItem, ModalityData, MultiModalDataDict, VideoItem)

 _T = TypeVar("_T")
 _I = TypeVar("_I")


 class ModalityDataItems(ABC, Generic[_T, _I]):
+    """
+    Represents data items for a modality in :class:`MultiModalDataItems`.
+    """

     def __init__(self, data: _T, modality: str) -> None:
         super().__init__()
@@ -69,6 +71,7 @@ def get_passthrough_data(self) -> Mapping[str, object]:


 class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
+    """Base class for data items that are arranged in a list."""

     def get_count(self) -> int:
         return len(self.data)
@@ -83,7 +86,12 @@ def get_passthrough_data(self) -> Mapping[str, object]:
         return {}


-class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
+class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
+                                       torch.Tensor]):
+    """
+    Base class for data items that are expressed as a batched embedding tensor,
+    or a list of embedding tensors (one per item).
+    """

     def get_count(self) -> int:
         return len(self.data)
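For readers skimming the hunk above, here is a minimal sketch of the two data layouts the widened `EmbeddingItems` signature accepts; the concrete shapes are illustrative assumptions, not mandated by the diff.

import torch

# Batched form: a single 3-D tensor holding one 2-D embedding per item,
# stacked along dim 0 (assumed shape: (num_items, seq_len, hidden_size)).
batched = torch.rand(3, 16, 1024)

# Per-item form: a list of 2-D tensors, which may differ in length.
per_item = [torch.rand(16, 1024), torch.rand(8, 1024), torch.rand(4, 1024)]

# Both satisfy len(data) == number of items, which is what get_count() returns.
assert len(batched) == len(per_item) == 3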
@@ -109,7 +117,7 @@ def __init__(self, data: Sequence[HfAudioItem]) -> None:

 class AudioEmbeddingItems(EmbeddingItems):

-    def __init__(self, data: NestedTensors) -> None:
+    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
         super().__init__(data, "audio")


@@ -137,7 +145,7 @@ def get_image_size(self, item_idx: int) -> ImageSize:

 class ImageEmbeddingItems(EmbeddingItems):

-    def __init__(self, data: NestedTensors) -> None:
+    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
         super().__init__(data, "image")


@@ -163,7 +171,7 @@ def get_frame_size(self, item_idx: int) -> ImageSize:

 class VideoEmbeddingItems(EmbeddingItems):

-    def __init__(self, data: NestedTensors) -> None:
+    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
         super().__init__(data, "video")


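A hedged usage sketch of the updated constructors follows, assuming this file is vllm/multimodal/parse.py (consistent with the relative imports above); all tensor shapes are illustrative.

import torch
from vllm.multimodal.parse import (AudioEmbeddingItems, ImageEmbeddingItems,
                                   VideoEmbeddingItems)

# Batched tensor: 2 images with 576 patch embeddings each (illustrative shape).
images = ImageEmbeddingItems(torch.rand(2, 576, 1024))

# List form: one embedding tensor per item; lengths may differ across items.
audios = AudioEmbeddingItems([torch.rand(100, 1024), torch.rand(60, 1024)])
videos = VideoEmbeddingItems([torch.rand(8 * 576, 1024)])

assert images.get_count() == 2
assert audios.get_count() == 2
assert videos.get_count() == 1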
@@ -172,8 +180,8 @@ def __init__(self, data: NestedTensors) -> None:

 class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
     """
-    As :class:`MultiModalDataDict`, but normalized such that each entry
-    corresponds to a list.
+    As :data:`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
+    such that each entry corresponds to a list.
     """

     def get_count(self, modality: str, *, strict: bool = True) -> int:
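The normalization the reworded docstring describes means a caller may pass either a bare item or a list of items per modality and read back a uniform structure. A sketch under one assumption: `parse_mm_data` is the parser's entry point in vLLM, but it is not shown in this hunk.

from PIL import Image
from vllm.multimodal.parse import MultiModalDataParser

parser = MultiModalDataParser()
image = Image.new("RGB", (336, 336))

# A bare item and a list of items normalize to the same list-per-modality form.
single = parser.parse_mm_data({"image": image})           # assumed entry point
multi = parser.parse_mm_data({"image": [image, image]})

assert single.get_count("image") == 1
assert multi.get_count("image") == 2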
@@ -226,7 +234,8 @@ def get_items(

 class MultiModalDataParser:
     """
-    Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
+    Parses :data:`~vllm.multimodal.inputs.MultiModalDataDict` into
+    :class:`MultiModalDataItems`.

     Args:
         target_sr (float, optional): Enables automatic resampling of audio
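And a sketch of the resampling option the Args entry describes; the `(audio, sampling_rate)` tuple input, the `parse_mm_data` call, and the 16 kHz target are assumptions for illustration only.

import numpy as np
from vllm.multimodal.parse import MultiModalDataParser

# Ask the parser to resample all audio inputs to 16 kHz (assumed target).
parser = MultiModalDataParser(target_sr=16000)

# One second of 44.1 kHz audio, passed as an (audio, sampling_rate) tuple.
audio = np.zeros(44100, dtype=np.float32)
items = parser.parse_mm_data({"audio": (audio, 44100)})  # assumed entry point

assert items.get_count("audio") == 1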
@@ -238,7 +247,9 @@ def __init__(self, *, target_sr: Optional[float] = None) -> None:

         self.target_sr = target_sr

-    def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
+    def _is_embeddings(
+            self, data: object
+    ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
         if isinstance(data, torch.Tensor):
             return data.ndim == 3
         if is_list_of(data, torch.Tensor):
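To make the guard's acceptance rule concrete, here is a standalone sketch of the same check. The hunk is cut off after the `is_list_of` branch, so the list handling below is an assumption, with a hand-rolled stand-in for vLLM's `is_list_of` utility to keep the example self-contained.

from typing import Union

import torch
from typing_extensions import TypeGuard


def is_embeddings(
    data: object
) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
    # A lone tensor counts as embeddings only if it is batched (3-D):
    # one 2-D embedding per item, stacked along dim 0.
    if isinstance(data, torch.Tensor):
        return data.ndim == 3
    # Stand-in for is_list_of(data, torch.Tensor); the real branch body past
    # this hunk is not shown, so this is an assumed equivalent.
    if isinstance(data, list) and data:
        return all(isinstance(t, torch.Tensor) for t in data)
    return False


assert is_embeddings(torch.rand(2, 16, 1024))   # batched 3-D tensor
assert is_embeddings([torch.rand(16, 1024)])    # list of per-item tensors
assert not is_embeddings(torch.rand(16, 1024))  # a single 2-D tensor is not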