examples/offline_inference/vision_language_multi_image.py (1 addition, 0 deletions)
@@ -85,6 +85,7 @@ def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData:
trust_remote_code=True,
max_model_len=8192,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"max_dynamic_patch": 4},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
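For context on the one-line change above: mm_processor_kwargs supplied at engine construction become the default preprocessing options for every request, so the example caps H2OVL's dynamic tiling once instead of per request. A minimal sketch of the surrounding call site; trust_remote_code, max_model_len, and limit_mm_per_prompt are visible in the diff, while the model id and the literal image count are illustrative assumptions:

from vllm import EngineArgs

# Editor's illustrative sketch (not part of the diff): how the new
# mm_processor_kwargs line fits into the example's engine configuration.
# Capping max_dynamic_patch bounds the number of image patches, and hence
# <IMG_CONTEXT> tokens, so multi-image prompts fit within max_model_len.
engine_args = EngineArgs(
    model="h2oai/h2ovl-mississippi-2b",  # assumed model id
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"image": 2},  # i.e. len(image_urls) in the example
    mm_processor_kwargs={"max_dynamic_patch": 4},
)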
tests/models/multimodal/processing/test_common.py (2 additions, 5 deletions)
@@ -10,7 +10,7 @@
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import ProcessingCache
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -42,10 +42,7 @@ def _test_processing_correctness(
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_info.trust_remote_code,
),
tokenizer=cached_tokenizer_from_config(model_config),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
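The import swap here (repeated in the test files below) replaces a tokenizer lookup that restated ModelConfig fields at every call site with a helper that takes the config directly. A hedged sketch of the intended equivalence, assuming cached_tokenizer_from_config simply forwards the config's tokenizer fields; the real implementation lives in vllm.transformers_utils.tokenizer:

from vllm.transformers_utils.tokenizer import cached_get_tokenizer

# Assumed behavior, for illustration only -- not the actual vLLM source.
def cached_tokenizer_from_config(model_config, **kwargs):
    return cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision,  # assumed field mapping
        **kwargs,
    )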
tests/models/multimodal/processing/test_h2ovl.py (131 additions, 94 deletions)
@@ -1,17 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for H2OVL's multimodal preprocessing kwargs."""
from typing import Optional
from typing import Mapping, Optional

import pytest
from PIL import Image
from transformers import PretrainedConfig

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from ....conftest import _ImageAssets
from ...utils import build_model_context


def _get_expected_num_patches(
config: PretrainedConfig,
image: Image.Image,
num_imgs: int,
min_num: int,
max_num: int,
):
from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
get_h2ovl_target_ratios)

width, height = image.size

# Calculate the expected number of blocks
if num_imgs == 1 and config.use_msac:
# First pass
blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num=1,
max_num=max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
)

# Second pass
blocks2, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num=3,
max_num=max_num,
prior_aspect_ratio=aspect_ratio,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)

# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0

# Total blocks is the sum of blocks from both passes minus
# overlapping
total_blocks = blocks1 + blocks2 - 1

return total_blocks

blocks, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
expected_num_patches = blocks

if config.use_thumbnail and expected_num_patches > 1:
expected_num_patches += 1

return expected_num_patches


def _run_check(
processor: BaseMultiModalProcessor,
images: list[Image.Image],
min_num: int,
max_num: int,
mm_processor_kwargs: Mapping[str, object],
):
tokenizer = processor.info.get_tokenizer()
config = processor.info.get_hf_config()

mm_data = {"image": images}

total_expected_num_patches = sum(
_get_expected_num_patches(config, image, len(images), min_num, max_num)
for image in images)

processed_inputs = processor.apply("<image>" * len(images), mm_data,
mm_processor_kwargs)

# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape

assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches


@pytest.mark.parametrize("model_id", [
"h2oai/h2ovl-mississippi-800m",
"h2oai/h2ovl-mississippi-2b",
@@ -25,118 +126,54 @@
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
[4.0, 2.0, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [1, 2, 4, 8])
@pytest.mark.parametrize(
("min_dynamic_patch", "max_dynamic_patch"),
[(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)],
)
@pytest.mark.parametrize("dynamic_image_size", [True, False])
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
size_factors: list[float],
min_dynamic_patch: int,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
num_imgs: int,
kwargs_on_init: bool,
):
from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
get_h2ovl_target_ratios)
mm_processor_kwargs = {
"min_dynamic_patch": min_dynamic_patch,
"max_dynamic_patch": max_dynamic_patch,
"dynamic_image_size": dynamic_image_size,
}

ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": len(size_factors)},
)
tokenizer = cached_tokenizer_from_config(ctx.model_config)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

config = processor.info.get_hf_config()
use_msac = config.use_msac

mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size

min_num = config.min_dynamic_patch
min_num = min_dynamic_patch if dynamic_image_size else 1
max_num = max_dynamic_patch if dynamic_image_size else 1

# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs

for asset in image_assets:
for factor in size_factors:
image = rescale_image_size(asset.pil_image, factor)
mm_data = {"image": [image] * num_imgs}

width, height = image.size

# Calculate the expected number of blocks
if num_imgs == 1 and use_msac:
# First pass
blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
)

# Second pass
blocks2, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=aspect_ratio,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)

# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0

# Total blocks is the sum of blocks from both passes minus
# overlapping
total_blocks = blocks1 + blocks2 - 1

expected_num_patches = total_blocks
else:
blocks, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
expected_num_patches = blocks

if config.use_thumbnail and expected_num_patches != 1:
expected_num_patches += 1

processed_inputs = processor.apply(prompt, mm_data,
mm_processor_kwargs)
pixel_shape = (
processed_inputs["mm_kwargs"]["pixel_values_flat"].shape)

assert pixel_shape[0] == expected_num_patches * num_imgs
_run_check(
processor,
[
rescale_image_size(image_assets[0].pil_image, f)
for f in size_factors
],
min_num,
max_num,
hf_processor_mm_kwargs,
)
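The subtlest part of the refactored expectation logic is the MSAC branch of _get_expected_num_patches: two tiling passes run with different min_num values, each pass may gain a thumbnail tile, and one overlapping tile is subtracted. A self-contained arithmetic sketch with made-up block counts (the real counts come from calculate_h2ovl_targets):

# Illustrative MSAC patch accounting; blocks1/blocks2 are hypothetical
# stand-ins for the values calculate_h2ovl_targets would return.
blocks1 = 4  # first pass: min_num=1, no prior aspect ratio
blocks2 = 6  # second pass: min_num=3, reusing the first pass's aspect ratio
use_thumbnail = True

if use_thumbnail:
    blocks1 += 1 if blocks1 > 1 else 0  # 4 -> 5
    blocks2 += 1 if blocks2 > 1 else 0  # 6 -> 7

total_blocks = blocks1 + blocks2 - 1  # shared tile counted once -> 11

# _run_check then asserts 256 <IMG_CONTEXT> tokens per patch:
assert total_blocks == 11
assert 256 * total_blocks == 2816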
tests/models/multimodal/processing/test_idefics3.py (16 additions, 8 deletions)
@@ -4,7 +4,7 @@
from transformers import Idefics3Config

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from ....conftest import _ImageAssets
from ...utils import build_model_context
@@ -22,9 +22,15 @@
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(image_assets: _ImageAssets, model: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int, num_imgs: int):
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,
num_imgs: int,
kwargs_on_init: bool,
):
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
@@ -33,15 +39,15 @@ def test_processor_override(image_assets: _ImageAssets, model: str,
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
tokenizer = cached_tokenizer_from_config(ctx.model_config)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

# Build the image str / prompt based on the number of images we pass
placeholders = "<image>" if num_imgs == 1 else "\n".join(
@@ -54,8 +60,10 @@ def test_processor_override(image_assets: _ImageAssets, model: str,
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}

processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)

# Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
"input_ids"][0]
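Both updated tests use the same routing for the new kwargs_on_init parameter: mm_processor_kwargs go either into build_model_context at engine-init time or into processor.apply at request time, never both, and the processed outputs must be identical either way. A minimal sketch of that routing (the helper name is invented for illustration):

# Hypothetical helper mirroring the kwargs_on_init pattern in these tests.
def route_mm_kwargs(kwargs_on_init: bool, mm_processor_kwargs: dict):
    init_kwargs = mm_processor_kwargs if kwargs_on_init else None
    request_kwargs = {} if kwargs_on_init else mm_processor_kwargs
    return init_kwargs, request_kwargs

# Init-time: kwargs are baked into the model context; apply() gets none.
assert route_mm_kwargs(True, {"max_dynamic_patch": 4}) == (
    {"max_dynamic_patch": 4}, {})
# Request-time: the context is built without kwargs; apply() receives them.
assert route_mm_kwargs(False, {"max_dynamic_patch": 4}) == (
    None, {"max_dynamic_patch": 4})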