
Commit 28907ec

jeejeelee authored and rasmith committed
[Bugfix][V1] Fix molmo text-only inputs (vllm-project#11676)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 1ad90f8 commit 28907ec

File tree

3 files changed: +123 −42 lines changed

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 10 additions & 0 deletions
@@ -341,6 +341,16 @@
         ),
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
     ),
+    "molmo": VLMTestInfo(
+        models=["allenai/Molmo-7B-D-0924"],
+        test_type=(VLMTestType.IMAGE),
+        prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        image_size_factors=[(),(1.0, 1.0, 1.0)],
+        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        postprocess_inputs=model_utils.molmo_post_processor,
+    ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
     # https://github.com/huggingface/transformers/issues/34307
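The noteworthy setting in this entry is image_size_factors=[(),(1.0, 1.0, 1.0)]: the empty tuple appears to produce a run with no images at all, exercising the text-only path this commit fixes, while (1.0, 1.0, 1.0) runs the same prompt with three full-size copies of the test image. A minimal sketch of what the prompt_formatter produces (the question string is an illustrative placeholder, not part of the test suite):

    # Reproduces the "molmo" entry's prompt template on a sample question.
    prompt_formatter = lambda img_prompt: "User: " + img_prompt + " Assistant:"
    print(prompt_formatter("What is in this image?"))
    # -> User: What is in this image? Assistant: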

tests/models/decoder_only/vision_language/vlm_utils/model_utils.py

Lines changed: 96 additions & 3 deletions
@@ -5,17 +5,20 @@
 import re
 import types
 from pathlib import PosixPath
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from PIL.Image import Image
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding
+from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
+                          GenerationConfig)

 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

-from .....conftest import HfRunner, ImageAsset, _ImageAssets
+from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
+                           PromptImageInput, PromptVideoInput, _ImageAssets)
+from ....utils import TokensTextLogprobs
 from .types import RunnerOutput

@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
     return {"model_inputs": hf_inputs}


+def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
+    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
+    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
+
+
 ####### Prompt path encoders for models that need models on disk
 def qwen_prompt_path_encoder(
         tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
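Molmo's HF processor returns per-example tensors without a batch dimension, so this helper first casts the image tensor to the test dtype and then unsqueezes every tensor to batch size 1. A standalone sketch of the unsqueeze step, with assumed illustrative shapes (the real crop/patch dimensions come from the Molmo image processor):

    import torch

    # Stand-ins for what the Molmo HF processor might return (shapes assumed).
    hf_inputs = {
        "input_ids": torch.tensor([1, 2, 3]),   # (seq_len,)
        "images": torch.randn(13, 576, 588),    # (crops, patches, pixels)
    }
    # molmo_post_processor's final step: give every tensor a batch dimension.
    batched = {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
    print(batched["input_ids"].shape)  # torch.Size([1, 3])
    print(batched["images"].shape)     # torch.Size([1, 13, 576, 588])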
@@ -451,3 +459,88 @@ def _generate(self, *args, **kwargs):
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)

     return hf_model
+
+
+def _generate_greedy_logprobs_limit(
+    self,
+    prompts: List[str],
+    max_tokens: int,
+    num_logprobs: int,
+    images: Optional[PromptImageInput] = None,
+    audios: Optional[PromptAudioInput] = None,
+    videos: Optional[PromptVideoInput] = None,
+    **kwargs: Any,
+) -> List[TokensTextLogprobs]:
+    all_inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
+    # Process in batches for inference.
+    if len(all_inputs):
+        input_ids_lst = []
+        images_lst = []
+        images_input_idx_lst = []
+        imges_masks_lst = []
+        for inputs in all_inputs:
+            input_ids_lst.append(inputs["input_ids"])
+            images_lst.append(inputs["images"])
+            images_input_idx_lst.append(inputs["image_input_idx"])
+            imges_masks_lst.append(inputs["image_masks"])
+        batch_inputs = {}
+        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
+        batch_inputs['images'] = torch.cat(images_lst, dim=0)
+        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
+                                                    dim=0)
+        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
+
+    outputs = self.model.generate_from_batch(
+        batch=self.wrap_device(batch_inputs,
+                               device=self.model.device.type),
+        generation_config=GenerationConfig(
+            max_new_tokens=max_tokens,
+            stop_strings="<|endoftext|>",
+            do_sample=False,
+        ),
+        tokenizer=self.tokenizer,
+        output_hidden_states=True,
+        return_dict_in_generate=True,
+    )
+
+    all_logprobs: List[List[Dict[int, float]]] = []
+    all_output_ids: List[List[int]] = []
+    all_output_strs: List[str] = []
+
+    for index in range(len(all_inputs)):
+        (
+            seq_logprobs_lst,
+            output_len,
+        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
+                                            num_logprobs)
+        all_logprobs.append(seq_logprobs_lst)
+        seq_ids = outputs.sequences[index]
+        output_ids = seq_ids[-output_len:]
+        all_output_ids.append(output_ids.tolist())
+        all_output_strs.append(self.tokenizer.decode(output_ids))
+    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+    return [(output_ids, output_str, output_logprobs)
+            for output_ids, output_str, output_logprobs in outputs]
+
+
+####### Molmo-specific HuggingFace runner patchers
+def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Molmo."""
+    hf_processor = hf_model.processor
+
+    def _processor(*args, **kwargs):
+        return hf_processor.process(*args, **kwargs)
+
+    hf_model.processor = _processor
+
+    setattr(  # noqa: B010
+        hf_model,
+        "generate_greedy_logprobs_limit",
+        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
+    )
+
+    return hf_model
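Because molmo_post_processor already added a leading batch dimension to every tensor, the patched runner can assemble one batch by concatenating the per-prompt tensors along dim 0 before calling Molmo's generate_from_batch; afterwards it decodes only the generated suffix of each sequence (seq_ids[-output_len:]), so the returned strings exclude the prompts. A standalone sketch of the concatenation step (tensor contents are illustrative):

    import torch

    # Two prompts, each already shaped (1, seq_len) by molmo_post_processor.
    per_prompt = [
        {"input_ids": torch.tensor([[1, 2, 3]])},
        {"input_ids": torch.tensor([[4, 5, 6]])},
    ]
    # Concatenate along the batch dimension, as the patched method does.
    batch = {"input_ids": torch.cat([p["input_ids"] for p in per_prompt], dim=0)}
    print(batch["input_ids"].shape)  # torch.Size([2, 3])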

vllm/model_executor/models/molmo.py

Lines changed: 17 additions & 39 deletions
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])

-    image_processor = processor.image_processor
-    max_total_crops = 1 + image_processor.max_crops
-    if image is not None:
-        images, image_input_idx, image_masks = pad_images(
-            max_total_crops,
-            out["images"],
-            out["image_input_idx"],
-            out.get("image_masks"),
-        )
-    else:
-        base_image_input_size = image_processor.base_image_input_size
-        image_patch_size = image_processor.image_patch_size
-        image_num_patch = (
-            base_image_input_size[0] // image_patch_size,
-            base_image_input_size[1] // image_patch_size,
-        )
-        n_pixels = image_patch_size * image_patch_size * 3
-        n_patches = image_num_patch[0] * image_num_patch[1]
-
-        image_length_w = image_processor.image_token_length_w
-        image_length_h = image_processor.image_token_length_h
-        tokens_per_image = image_length_w * image_length_h
-        images = torch.full(
-            (max_total_crops, n_patches, n_pixels),
-            -1,
-            dtype=torch.float32,
-        )
-        image_input_idx = torch.full(
-            (max_total_crops, tokens_per_image),
-            -1,
-            dtype=torch.int32,
+    # If there is no image, return directly.
+    if image is None:
+        new_prompt_token_ids = out["input_ids"].tolist()
+        prompt = inputs.get("prompt")
+        if prompt is None:
+            prompt = tokenizer.decode(new_prompt_token_ids)
+        return token_inputs(
+            prompt_token_ids=new_prompt_token_ids,
+            prompt=prompt,
         )
-    if image_processor.image_padding_mask:
-        image_masks = torch.full(
-            (max_total_crops, n_patches),
-            -1,
-            dtype=torch.float32,
-        )

+    image_processor = processor.image_processor
+    max_total_crops = 1 + image_processor.max_crops
+    images, image_input_idx, image_masks = pad_images(
+        max_total_crops,
+        out["images"],
+        out["image_input_idx"],
+        out.get("image_masks"),
+    )
     image_data = dict(
         images=images,
         image_input_idx=image_input_idx,
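This hunk is the substance of the fix: a text-only prompt previously fell through to the else branch and fabricated full-size placeholder image tensors even though no image was supplied, which broke text-only requests on the V1 engine; now it returns plain token inputs immediately, and pad_images only ever runs on real image outputs. A hypothetical reproduction of the now-working case through vLLM's offline API (the model name comes from the tests; the prompt and sampling parameters are illustrative):

    from vllm import LLM, SamplingParams

    # Text-only prompt to a multimodal Molmo model -- the case this commit fixes.
    llm = LLM(model="allenai/Molmo-7B-D-0924", trust_remote_code=True)
    outputs = llm.generate("User: Write a haiku about GPUs. Assistant:",
                           SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)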
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
                 offset = i
             size += 1
     image_data["image_start_end"] = (offset, offset + size)
-
     prompt = inputs.get("prompt")
     if prompt is None:
         prompt = tokenizer.decode(new_prompt_token_ids)
-
     return token_inputs(
         prompt_token_ids=new_prompt_token_ids,
         prompt=prompt,
