 import re
 import types
 from pathlib import PosixPath
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from PIL.Image import Image
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding
+from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
+                          GenerationConfig)

 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

-from .....conftest import HfRunner, ImageAsset, _ImageAssets
+from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
+                           PromptImageInput, PromptVideoInput, _ImageAssets)
+from ....utils import TokensTextLogprobs
 from .types import RunnerOutput


@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
     return {"model_inputs": hf_inputs}


+def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
+    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
+    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
+
+
 ####### Prompt path encoders for models that need models on disk
 def qwen_prompt_path_encoder(
         tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
@@ -451,3 +459,88 @@ def _generate(self, *args, **kwargs):
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)

     return hf_model
+
+
+def _generate_greedy_logprobs_limit(
+    self,
+    prompts: List[str],
+    max_tokens: int,
+    num_logprobs: int,
+    images: Optional[PromptImageInput] = None,
+    audios: Optional[PromptAudioInput] = None,
+    videos: Optional[PromptVideoInput] = None,
+    **kwargs: Any,
+) -> List[TokensTextLogprobs]:
+    all_inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
+    # Process in batches for inference.
+    if len(all_inputs):
+        input_ids_lst = []
+        images_lst = []
+        images_input_idx_lst = []
+        images_masks_lst = []
+        for inputs in all_inputs:
+            input_ids_lst.append(inputs["input_ids"])
+            images_lst.append(inputs["images"])
+            images_input_idx_lst.append(inputs["image_input_idx"])
+            images_masks_lst.append(inputs["image_masks"])
+        batch_inputs = {}
+        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
+        batch_inputs['images'] = torch.cat(images_lst, dim=0)
+        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
+                                                    dim=0)
+        batch_inputs['image_masks'] = torch.cat(images_masks_lst, dim=0)
+
+    outputs = self.model.generate_from_batch(
+        batch=self.wrap_device(batch_inputs,
+                               device=self.model.device.type),
+        generation_config=GenerationConfig(
+            max_new_tokens=max_tokens,
+            stop_strings="<|endoftext|>",
+            do_sample=False,
+        ),
+        tokenizer=self.tokenizer,
+        output_hidden_states=True,
+        return_dict_in_generate=True,
+    )
+
+    all_logprobs: List[List[Dict[int, float]]] = []
+    all_output_ids: List[List[int]] = []
+    all_output_strs: List[str] = []
+
+    for index in range(len(all_inputs)):
+        (
+            seq_logprobs_lst,
+            output_len,
+        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
+                                            num_logprobs)
+        all_logprobs.append(seq_logprobs_lst)
+        seq_ids = outputs.sequences[index]
+        output_ids = seq_ids[-output_len:]
+        all_output_ids.append(output_ids.tolist())
+        all_output_strs.append(self.tokenizer.decode(output_ids))
+    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+    return [(output_ids, output_str, output_logprobs)
+            for output_ids, output_str, output_logprobs in outputs]
+
+
+####### Molmo-specific HuggingFace runner patchers
+def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Molmo."""
+    hf_processor = hf_model.processor
+
+    def _processor(*args, **kwargs):
+        return hf_processor.process(*args, **kwargs)
+
+    hf_model.processor = _processor
+
+    setattr(  # noqa: B010
+        hf_model,
+        "generate_greedy_logprobs_limit",
+        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
+    )
+
+    return hf_model
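Taken together, molmo_post_processor and molmo_patch_hf_runner let the shared HfRunner produce HF-side reference outputs for Molmo. A rough usage sketch follows; the hf_model, example_prompts, and example_images names are placeholders for illustration and are not part of this commit:

# Hypothetical sketch (placeholder names): patch an existing HfRunner
# instance and call the injected greedy-logprobs method, mirroring how the
# vision-language test cases consume these helpers.
hf_model = molmo_patch_hf_runner(hf_model)  # hf_model: an HfRunner instance
hf_outputs = hf_model.generate_greedy_logprobs_limit(
    example_prompts,        # List[str] of Molmo-formatted prompts
    max_tokens=128,
    num_logprobs=5,
    images=example_images,  # one image (or list of images) per prompt
)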