Commit 9282957

Migrate Gemma3ImagePixelInputs to TensorSchema
Signed-off-by: Benji Beck <[email protected]>
1 parent 57c22e5 commit 9282957

File tree

1 file changed: +21 −25 lines


vllm/model_executor/models/gemma3_mm.py

Lines changed: 21 additions & 25 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict
+from typing import Annotated, Any, Literal, Optional
 
 import torch
 from torch import nn
@@ -31,6 +31,7 @@
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -42,18 +43,21 @@
 logger = init_logger(__name__)
 
 
-class Gemma3ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class Gemma3ImagePixelInputs(TensorSchema):
     """
-    Shape: `(num_patches_total, num_channels, height, width)`
-
-    `num_patches_total` is the total number of patches
-    over each image over each prompt in the batch.
+    Dimensions:
+        - p: Number of patches total (over each image over each prompt in the
+          batch)
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+        - bn: Batch size * number of images
     """
+    type: Literal["pixel_values"] = "pixel_values"
+
+    pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]
 
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -523,15 +527,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def dtype(self):
         return next(self.parameters()).dtype
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        image_size = self.config.vision_config.image_size
-        expected_dims = (3, image_size, image_size)
-        if data.shape[1:] != expected_dims:
-            raise ValueError(
-                "The expected shape of pixel values per image per batch is "
-                f"{expected_dims}. You supplied {tuple(data.shape)}.")
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -549,14 +544,15 @@ def _parse_and_validate_image_input(
             raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
 
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        num_crops = flatten_bn(num_crops, concat=True)
+        image_size = self.config.vision_config.image_size
 
         return Gemma3ImagePixelInputs(
-            type="pixel_values",
-            pixel_values=self._validate_pixel_values(pixel_values),
-            num_patches=num_crops + 1,
-        )
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            num_patches=flatten_bn(num_crops, concat=True) + 1,
+            resolve_bindings={
+                "h": image_size,
+                "w": image_size
+            })
 
     def _image_pixels_to_features(
         self,
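
Usage note (not part of the commit): the sketch below illustrates how the migrated schema is expected to be constructed and validated, replacing the hand-rolled checks in the removed _validate_pixel_values. It assumes vLLM is installed and that TensorSchema validates annotated tensor shapes at construction, with "h"/"w" bound through resolve_bindings; the image_size and patch counts are made-up example values.

# Minimal sketch, assuming vLLM is installed; numbers are illustrative only.
import torch

from vllm.model_executor.models.gemma3_mm import Gemma3ImagePixelInputs

image_size = 896          # hypothetical vision_config.image_size
num_patches_total = 4     # e.g. 2 prompts x 2 patches each

# The shape check previously done in _validate_pixel_values is now declared
# on the field: pixel_values must match ("p", 3, "h", "w"), with "h"/"w"
# resolved to image_size via resolve_bindings.
inputs = Gemma3ImagePixelInputs(
    pixel_values=torch.zeros(num_patches_total, 3, image_size, image_size),
    num_patches=torch.tensor([2, 2]),
    resolve_bindings={"h": image_size, "w": image_size},
)

# A mismatched tensor (wrong channel count, height, or width) should now be
# rejected by the schema instead of the old explicit ValueError.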
