2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
3
import math
4
4
from collections .abc import Iterable , Mapping , Sequence
5
- from typing import Any , Literal , Optional , TypedDict
5
+ from typing import Annotated , Any , Literal , Optional
6
6
7
7
import torch
8
8
from torch import nn
31
31
# yapf: enable
32
32
from vllm .multimodal .profiling import BaseDummyInputsBuilder
33
33
from vllm .sequence import IntermediateTensors
34
+ from vllm .utils .tensor_schema import TensorSchema , TensorShape
34
35
35
36
from .interfaces import (MultiModalEmbeddings , SupportsLoRA ,
36
37
SupportsMultiModal , SupportsPP )
42
43
logger = init_logger (__name__ )
43
44
44
45
45
class Gemma3ImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - p: Number of patches total (over each image over each prompt in the
          batch)
        - c: Number of channels (3)
        - h: Height of each patch
        - w: Width of each patch
        - bn: Batch size * number of images
    """
    # Discriminator retained for parity with the previous TypedDict form;
    # pixel values are the only image-input variant for this model.
    type: Literal["pixel_values"] = "pixel_values"

    # Flattened patches across the whole batch; channel count is fixed at 3,
    # h/w are bound to the vision config's image_size by the caller via
    # resolve_bindings.
    pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]

    # Per-image patch counts; length is batch_size * num_images.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]


# Single-variant union: alias rather than a typing.Union.
Gemma3ImageInputs = Gemma3ImagePixelInputs
    def dtype(self):
        """Return the dtype of this module's parameters (taken from the
        first parameter yielded by ``self.parameters()``)."""
        return next(self.parameters()).dtype
525
529
526
- def _validate_pixel_values (self , data : torch .Tensor ) -> torch .Tensor :
527
- image_size = self .config .vision_config .image_size
528
- expected_dims = (3 , image_size , image_size )
529
- if data .shape [1 :] != expected_dims :
530
- raise ValueError (
531
- "The expected shape of pixel values per image per batch is "
532
- f"{ expected_dims } . You supplied { tuple (data .shape )} ." )
533
- return data
534
-
535
530
def _parse_and_validate_image_input (
536
531
self , ** kwargs : object ) -> Optional [Gemma3ImageInputs ]:
537
532
pixel_values = kwargs .pop ("pixel_values" , None )
@@ -549,14 +544,15 @@ def _parse_and_validate_image_input(
549
544
raise ValueError ("Incorrect type of num_crops. "
550
545
f"Got type: { type (num_crops )} " )
551
546
552
- pixel_values = flatten_bn (pixel_values , concat = True )
553
- num_crops = flatten_bn (num_crops , concat = True )
547
+ image_size = self .config .vision_config .image_size
554
548
555
549
return Gemma3ImagePixelInputs (
556
- type = "pixel_values" ,
557
- pixel_values = self ._validate_pixel_values (pixel_values ),
558
- num_patches = num_crops + 1 ,
559
- )
550
+ pixel_values = flatten_bn (pixel_values , concat = True ),
551
+ num_patches = flatten_bn (num_crops , concat = True ) + 1 ,
552
+ resolve_bindings = {
553
+ "h" : image_size ,
554
+ "w" : image_size
555
+ })
560
556
561
557
def _image_pixels_to_features (
562
558
self ,
0 commit comments