@@ -7,7 +7,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import regex as re
 import torch
@@ -32,6 +32,7 @@
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -62,51 +63,60 @@ def forward(self, image_features):
         return hidden_states
 
 
-class InternS1ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class InternS1ImagePixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+    Dimensions:
+        - bnp: Batch size * number of images * (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bn: Batch size * number of images
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-class InternS1ImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
+class InternS1ImageEmbeddingInputs(TensorSchema):
     """
-    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+    Dimensions:
+        - ni: Number of images
+        - tifs: Total image feature size
+        - hs: Hidden size (must match language model backbone)
     """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("ni", "tifs", "hs")]
 
 
 InternS1ImageInputs = Union[InternS1ImagePixelInputs,
                             InternS1ImageEmbeddingInputs]
 
 
-class InternS1VideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    pixel_values: torch.Tensor
+class InternS1VideoPixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_video * num_frames, num_channels, height, width)`
+    Dimensions:
+        - bnv: Batch size * number of videos * number of frames
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
     """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bnv", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-class InternS1VideoEmbeddingInputs(TypedDict):
-    type: Literal["video_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
+class InternS1VideoEmbeddingInputs(TensorSchema):
     """
-    A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_video_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+    Dimensions:
+        - nv: Number of videos
+        - tvfs: Total video feature size
+        - hs: Hidden size (must match language model backbone)
     """
+    type: Literal["video_embeds"] = "video_embeds"
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("nv", "tvfs", "hs")]
 
 
 InternS1VideoInputs = Union[InternS1VideoPixelInputs,
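
A minimal usage sketch of the new schema (illustration only, not part of the diff): it assumes TensorSchema checks each field against its annotated TensorShape when an instance is constructed, with resolve_bindings pinning the symbolic "h"/"w" dimensions, and it assumes a 448x448 configured image size. InternS1ImagePixelInputs is the class defined in the hunk above.

    import torch

    # Illustrative tensors: 6 flattened patch tensors from 2 images
    # (per-image patch counts are made up for the example).
    pixel_values = torch.randn(6, 3, 448, 448)  # ("bnp", 3, "h", "w")
    num_patches = torch.tensor([1, 3])          # ("bn",)

    inputs = InternS1ImagePixelInputs(
        type="pixel_values",
        pixel_values=pixel_values,
        num_patches=num_patches,
        # Bind the symbolic height/width to the configured image size,
        # mirroring the resolve_bindings usage added later in this diff.
        resolve_bindings={"h": 448, "w": 448},
    )
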
@@ -572,26 +582,6 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
         vit_embeds = self.multi_modal_projector(vit_embeds)
         return vit_embeds
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-
-        h, w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f" per patch is {expected_expr}. "
-                    f"You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[InternS1ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -627,10 +617,15 @@ def _parse_and_validate_image_input(
             pixel_values = flatten_bn(pixel_values, concat=True)
             image_num_patches = flatten_bn(image_num_patches, concat=True)
 
+            h, w = self.config.vision_config.image_size
             return InternS1ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(pixel_values),
+                pixel_values=pixel_values,
                 num_patches=image_num_patches,
+                resolve_bindings={
+                    "h": h,
+                    "w": w,
+                },
             )
 
         raise AssertionError("This line should be unreachable.")
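
Illustration only (not part of the diff): with the bindings above, a pixel tensor whose spatial size disagrees with the configured image size should now be rejected when the schema instance is built, which is what allowed the hand-written _validate_pixel_values helper to be removed. The exact exception raised by TensorSchema is an assumption here.

    import torch

    bad_pixels = torch.randn(6, 3, 448, 224)  # width does not match the bound "w"
    try:
        InternS1ImagePixelInputs(
            type="pixel_values",
            pixel_values=bad_pixels,
            num_patches=torch.tensor([1, 3]),
            resolve_bindings={"h": 448, "w": 448},
        )
    except Exception as exc:  # assumed: TensorSchema raises on a shape mismatch
        print(f"rejected by schema: {exc}")
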
@@ -671,11 +666,15 @@ def _parse_and_validate_video_input(
                                                  concat=True)
             video_num_patches = flatten_bn(video_num_patches, concat=True)
 
+            h, w = self.config.vision_config.image_size
             return InternS1VideoPixelInputs(
                 type="pixel_values_videos",
-                pixel_values=self._validate_pixel_values(
-                    pixel_values_flat_video),
                 num_patches=video_num_patches,
+                pixel_values=pixel_values_flat_video,
+                resolve_bindings={
+                    "h": h,
+                    "w": w,
+                },
             )
 
         raise AssertionError("This line should be unreachable.")