16 changes: 0 additions & 16 deletions vllm/model_executor/models/gemma3_mm.py
@@ -63,9 +63,6 @@ class Gemma3ImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
 
@@ -319,11 +316,6 @@ def _call_hf_processor(
             tokenizer.encode(image_repl, add_special_tokens=False)
             for image_repl in image_repl_features
         ]
-        num_embeds = [
-            len(image_repl_feature_tokens)
-            for image_repl_feature_tokens in image_repls_feature_tokens
-        ]
-        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
 
         vocab = tokenizer.get_vocab()
         image_token_id = vocab[tokenizer.image_token]
@@ -356,7 +348,6 @@ def _get_mm_fields_config(
                 "image", num_crops + 1),
             num_crops=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -585,7 +576,6 @@ def _parse_and_validate_image_input(
         pixel_values = kwargs.pop("pixel_values", None)
         num_crops = kwargs.pop("num_crops", None)
         embed_is_patch = kwargs.pop("embed_is_patch", None)
-        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
@@ -603,10 +593,6 @@
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        if not isinstance(num_embeds, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_embeds. "
-                             f"Got type: {type(num_embeds)}")
-
         pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
 
@@ -615,7 +601,6 @@
             pixel_values=self._validate_pixel_values(pixel_values),
             num_patches=num_crops + 1,
             embed_is_patch=embed_is_patch,
-            num_embeds=num_embeds,
         )
 
     def _image_pixels_to_features(
@@ -658,7 +643,6 @@ def get_multimodal_embeddings(
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))

14 changes: 0 additions & 14 deletions vllm/model_executor/models/internvl.py
@@ -69,9 +69,6 @@ class InternVLImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -426,7 +423,6 @@ def __call__(
         tokenizer = self.tokenizer
         image_token_id = self.image_token_id
 
-        num_embeds = list[int]()
         embed_is_patch = list[torch.Tensor]()
 
         for pixel_values in pixel_values_lst:
@@ -438,11 +434,9 @@
                                               add_special_tokens=False)
 
             text = [t.replace('<image>', image_repl.full, 1) for t in text]
-            num_embeds.append(len(feature_tokens))
             embed_is_patch.append(
                 torch.tensor(feature_tokens) == image_token_id)
 
-        image_inputs["num_embeds"] = torch.tensor(num_embeds)
         image_inputs["embed_is_patch"] = embed_is_patch
 
         text_inputs = self.tokenizer(text)
@@ -607,7 +601,6 @@ def _get_mm_fields_config(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -840,7 +833,6 @@ def _parse_and_validate_image_input(
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
         embed_is_patch = kwargs.pop("embed_is_patch", None)
-        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values_flat is None and image_embeds is None:
@@ -873,10 +865,6 @@
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        if not isinstance(num_embeds, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_embeds. "
-                             f"Got type: {type(num_embeds)}")
-
         pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
         image_num_patches = flatten_bn(image_num_patches, concat=True)
 
@@ -886,7 +874,6 @@
                 pixel_values_flat),
             num_patches=image_num_patches,
             embed_is_patch=embed_is_patch,
-            num_embeds=num_embeds,
         )
 
         raise AssertionError("This line should be unreachable.")
@@ -941,7 +928,6 @@ def get_multimodal_embeddings(
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))

16 changes: 0 additions & 16 deletions vllm/model_executor/models/llava.py
@@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -358,15 +355,10 @@ def _call_hf_processor(
                 image_height=pixel_value.shape[-2],
             ) for pixel_value in processed_outputs["pixel_values"]
         ]
-        num_embeds = torch.tensor([(ncols + 1) * nrows
-                                   for ncols, nrows in tile_sizes])
-        # Each image may result to masks of different sizes, so we need to
-        # later use `num_embeds` to get per-image masks.
         embed_is_patch = [
             torch.tensor(([True] * ncols + [False]) * nrows)
             for ncols, nrows in tile_sizes
         ]
-        processed_outputs["num_embeds"] = num_embeds
         processed_outputs["embed_is_patch"] = embed_is_patch
 
         return processed_outputs
@@ -378,7 +370,6 @@ def _get_mm_fields_config(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
@@ -627,16 +618,10 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
 
-            num_embeds = kwargs.pop("num_embeds")
-            if not isinstance(num_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_embeds. "
-                                 f"Got type: {type(num_embeds)}")
-
             return PixtralHFImagePixelInputs(
                 type="pixel_values_pixtral",
                 pixel_values=flatten_bn(pixel_values),
                 embed_is_patch=embed_is_patch,
-                num_embeds=num_embeds,
             )
 
         return LlavaImagePixelInputs(
@@ -738,7 +723,6 @@ def get_multimodal_embeddings(
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 vision_embeddings,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))

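The PixtralHF mask built above, ([True] * ncols + [False]) * nrows, marks each tile row's ncols patch tokens followed by one row-break token, so the mask's length already equals the (ncols + 1) * nrows count that the removed num_embeds tensor used to record. A toy check of that identity (the tile-grid values here are assumptions for illustration):

    import torch

    ncols, nrows = 3, 2  # hypothetical tile grid for one image
    embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
    print(embed_is_patch.numel())       # 8
    print((ncols + 1) * nrows)          # 8 -- same count, no extra tensor needed
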
14 changes: 0 additions & 14 deletions vllm/model_executor/models/pixtral.py
@@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class PixtralProcessorAdapter:
     """
@@ -153,7 +150,6 @@ def __call__(
         images_processed = list[torch.Tensor]()
         images_tokens = list[torch.Tensor]()
         images_embed_is_patch = list[torch.Tensor]()
-        images_num_embeds = list[int]()
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
@@ -163,13 +159,11 @@
             images_processed.append(image_processed)
             images_tokens.append(image_tokens)
             images_embed_is_patch.append(image_tokens == image_token_id)
-            images_num_embeds.append(len(image_tokens))
 
         return {
             "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
             "images": images_processed,
             "embed_is_patch": images_embed_is_patch,
-            "num_embeds": torch.tensor(images_num_embeds),
         }
 
 
@@ -273,7 +267,6 @@ def _get_mm_fields_config(
         return dict(
             images=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -394,16 +387,10 @@ def _parse_and_validate_image_input(
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        num_embeds = kwargs.pop("num_embeds")
-        if not isinstance(num_embeds, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_embeds. "
-                             f"Got type: {type(num_embeds)}")
-
         return PixtralImagePixelInputs(
             type="pixel_values",
             images=flatten_bn(images),
             embed_is_patch=embed_is_patch,
-            num_embeds=num_embeds,
        )
 
     def _process_image_input(
@@ -447,7 +434,6 @@ def get_multimodal_embeddings(
         return flatten_2d_lists(
            scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))

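The vision.py change below is what makes the num_embeds plumbing in all four model files redundant: once embed_is_patch is batched into a single (num_images, num_embeds) tensor, the per-image embed count is just the mask's second dimension. A minimal sketch of that derivation, taken from the new lines in the diff (the toy mask values are assumptions for illustration):

    import torch

    # Two images, each with 4 embedding slots; True marks patch tokens.
    embed_is_patch = torch.tensor([
        [False, True, True, False],
        [False, True, True, False],
    ])
    num_images, num_embeds = embed_is_patch.shape
    num_embeds_per_image = [num_embeds] * num_images
    print(num_embeds_per_image)  # [4, 4]
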
29 changes: 25 additions & 4 deletions vllm/model_executor/models/vision.py
@@ -155,7 +155,6 @@ def resolve_visual_encoder_outputs(
 
 def scatter_patch_features(
     features: torch.Tensor,
-    num_embeds: torch.Tensor,
     embed_is_patch: torch.Tensor,
 ) -> tuple[torch.Tensor, ...]:
     """
@@ -168,13 +167,35 @@
     Args:
         features: The patch features, concatenated across each image.
             Shape: `(num_patch, feature_depth)`
-        num_embeds: The number of image embeddings for each image.
-            Shape: `(num_images,)`
         embed_is_patch: A boolean mask indicating which image embeddings
             correspond to patch tokens for each image.
             Shape: `(num_images, num_embeds)`
+
+    Note:
+        The original code only considers patch tokens as feature
+        tokens, but our processor considers all image-related tokens
+        as feature tokens because the feature tokens need to be
+        consecutive in `input_ids`.
+
+    Example:
+        A simplified example for one image:
+
+        .. code-block::
+
+            Embedding tokens (from HF processor):
+            [ <start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
+
+            embed_is_patch (from HF processor):
+            [  False   True    True   False  True    True   False False ]
+
+            Encoder outputs (from model):
+            [ p1 p2 p3 p4 ]
+
+            The resulting embedding tensor is:
+            [ nan p1 p2 nan p3 p4 nan nan ]
     """
-    num_embeds_per_image: list[int] = num_embeds.tolist()
+    num_images, num_embeds = embed_is_patch.shape
+    num_embeds_per_image = [num_embeds] * num_images
 
     embeds_flat = features.new_full(
         (sum(num_embeds_per_image), features.shape[-1]),
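To see the updated helper end to end, here is a self-contained sketch that mimics scatter_patch_features with its new two-argument signature. The diff above cuts off after the new_full call, so the scatter and split steps below are assumptions; treat this as an illustration of the technique, not the exact vLLM implementation.

    import torch

    def scatter_patch_features_sketch(
        features: torch.Tensor,        # (num_patches, feature_depth)
        embed_is_patch: torch.Tensor,  # (num_images, num_embeds), bool
    ) -> tuple[torch.Tensor, ...]:
        # Per-image counts now come from the mask shape (see diff above).
        num_images, num_embeds = embed_is_patch.shape
        num_embeds_per_image = [num_embeds] * num_images

        # NaN-fill every embedding slot, then scatter the encoder outputs
        # into the slots marked True. (Assumed continuation of the diff.)
        embeds_flat = features.new_full(
            (sum(num_embeds_per_image), features.shape[-1]),
            fill_value=torch.nan,
        )
        embeds_flat[embed_is_patch.view(-1)] = features
        return embeds_flat.split(num_embeds_per_image)

    # The docstring example: one image, 8 slots, 4 patch features p1..p4.
    embed_is_patch = torch.tensor(
        [[False, True, True, False, True, True, False, False]])
    features = torch.arange(1.0, 5.0).unsqueeze(-1)  # p1..p4, depth 1
    (image_embeds,) = scatter_patch_features_sketch(features, embed_is_patch)
    print(image_embeds.squeeze(-1))  # [nan, 1., 2., nan, 3., 4., nan, nan]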