Skip to content

Commit e3ad17d

Browse files
DarkLight1337mzusman
authored andcommitted
[Misc] Abstract the logic for reading and writing media content (vllm-project#11527)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent f402b6f commit e3ad17d

File tree

10 files changed

+493
-387
lines changed

10 files changed

+493
-387
lines changed

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class MockModelConfig:
3333
hf_config = MockHFConfig()
3434
logits_processor_pattern = None
3535
diff_sampling_param: Optional[dict] = None
36+
allowed_local_media_path: str = ""
3637

3738
def get_diff_sampling_param(self):
3839
return self.diff_sampling_param or {}

tests/entrypoints/test_chat_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Optional
33

44
import pytest
5-
from PIL import Image
65

76
from vllm.assets.image import ImageAsset
87
from vllm.config import ModelConfig
@@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input(
9190
image_data = mm_data.get("image")
9291
assert image_data is not None
9392

94-
if image_count == 1:
95-
assert isinstance(image_data, Image.Image)
96-
else:
97-
assert isinstance(image_data, list) and len(image_data) == image_count
93+
assert isinstance(image_data, list) and len(image_data) == image_count
9894

9995

10096
def test_parse_chat_messages_single_image(

tests/multimodal/test_utils.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from PIL import Image, ImageChops
1010
from transformers import AutoConfig, AutoTokenizer
1111

12-
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
12+
from vllm.multimodal.utils import (MediaConnector,
1313
repeat_and_pad_placeholder_tokens)
1414

1515
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@@ -23,7 +23,12 @@
2323

2424
@pytest.fixture(scope="module")
2525
def url_images() -> Dict[str, Image.Image]:
26-
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
26+
connector = MediaConnector()
27+
28+
return {
29+
image_url: connector.fetch_image(image_url)
30+
for image_url in TEST_IMAGE_URLS
31+
}
2732

2833

2934
def get_supported_suffixes() -> Tuple[str, ...]:
@@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
4348
@pytest.mark.asyncio
4449
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
4550
async def test_fetch_image_http(image_url: str):
46-
image_sync = fetch_image(image_url)
47-
image_async = await async_fetch_image(image_url)
51+
connector = MediaConnector()
52+
53+
image_sync = connector.fetch_image(image_url)
54+
image_async = await connector.fetch_image_async(image_url)
4855
assert _image_equals(image_sync, image_async)
4956

5057

@@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
5360
@pytest.mark.parametrize("suffix", get_supported_suffixes())
5461
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
5562
image_url: str, suffix: str):
63+
connector = MediaConnector()
5664
url_image = url_images[image_url]
5765

5866
try:
@@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
7583
base64_image = base64.b64encode(f.read()).decode("utf-8")
7684
data_url = f"data:{mime_type};base64,{base64_image}"
7785

78-
data_image_sync = fetch_image(data_url)
86+
data_image_sync = connector.fetch_image(data_url)
7987
if _image_equals(url_image, Image.open(f)):
8088
assert _image_equals(url_image, data_image_sync)
8189
else:
8290
pass # Lossy format; only check that image can be opened
8391

84-
data_image_async = await async_fetch_image(data_url)
92+
data_image_async = await connector.fetch_image_async(data_url)
8593
assert _image_equals(data_image_sync, data_image_async)
8694

8795

8896
@pytest.mark.asyncio
8997
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
9098
async def test_fetch_image_local_files(image_url: str):
99+
connector = MediaConnector()
100+
91101
with TemporaryDirectory() as temp_dir:
92-
origin_image = fetch_image(image_url)
102+
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
103+
104+
origin_image = connector.fetch_image(image_url)
93105
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
94106
quality=100,
95107
icc_profile=origin_image.info.get('icc_profile'))
96108

97-
image_async = await async_fetch_image(
98-
f"file://{temp_dir}/{os.path.basename(image_url)}",
99-
allowed_local_media_path=temp_dir)
100-
101-
image_sync = fetch_image(
102-
f"file://{temp_dir}/{os.path.basename(image_url)}",
103-
allowed_local_media_path=temp_dir)
109+
image_async = await local_connector.fetch_image_async(
110+
f"file://{temp_dir}/{os.path.basename(image_url)}")
111+
image_sync = local_connector.fetch_image(
112+
f"file://{temp_dir}/{os.path.basename(image_url)}")
104113
# Check that the images are equal
105114
assert not ImageChops.difference(image_sync, image_async).getbbox()
106115

107-
with pytest.raises(ValueError):
108-
await async_fetch_image(
109-
f"file://{temp_dir}/../{os.path.basename(image_url)}",
110-
allowed_local_media_path=temp_dir)
111-
with pytest.raises(ValueError):
112-
await async_fetch_image(
116+
with pytest.raises(ValueError, match="must be a subpath"):
117+
await local_connector.fetch_image_async(
118+
f"file://{temp_dir}/../{os.path.basename(image_url)}")
119+
with pytest.raises(RuntimeError, match="Cannot load local files"):
120+
await connector.fetch_image_async(
113121
f"file://{temp_dir}/../{os.path.basename(image_url)}")
114122

115-
with pytest.raises(ValueError):
116-
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
117-
allowed_local_media_path=temp_dir)
118-
with pytest.raises(ValueError):
119-
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")
123+
with pytest.raises(ValueError, match="must be a subpath"):
124+
local_connector.fetch_image(
125+
f"file://{temp_dir}/../{os.path.basename(image_url)}")
126+
with pytest.raises(RuntimeError, match="Cannot load local files"):
127+
connector.fetch_image(
128+
f"file://{temp_dir}/../{os.path.basename(image_url)}")
120129

121130

122131
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])

vllm/assets/audio.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@ class AudioAsset:
2121
name: Literal["winning_call", "mary_had_lamb"]
2222

2323
@property
24-
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
24+
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
2525
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
2626
s3_prefix=ASSET_DIR)
27-
y, sr = librosa.load(audio_path, sr=None)
28-
assert isinstance(sr, int)
29-
return y, sr
27+
return librosa.load(audio_path, sr=None)
3028

3129
@property
3230
def url(self) -> str:

0 commit comments

Comments
 (0)