Commit b18e3ee

HwwwwwwwH, llsj14, shaochangxu, DarkLight1337, and NickLucche committed
[Model] Refactoring of MiniCPM-V and add MiniCPM-o-2.6 support for vLLM (vllm-project#12069)
Signed-off-by: hzh <[email protected]>
Signed-off-by: Sungjae Lee <[email protected]>
Signed-off-by: shaochangxu.scx <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Akshat Tripathi <[email protected]>
Signed-off-by: Oleg Mosalov <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Yida Wu <[email protected]>
Signed-off-by: Chenguang Li <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Shanshan Shen <[email protected]>
Signed-off-by: elijah <[email protected]>
Signed-off-by: Yikun <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Konrad Zawora <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Co-authored-by: Sungjae Lee <[email protected]>
Co-authored-by: shaochangxu <[email protected]>
Co-authored-by: shaochangxu.scx <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
Co-authored-by: sixgod <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Rafael Vasquez <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Akshat Tripathi <[email protected]>
Co-authored-by: Oleg Mosalov <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: Avshalom Manevich <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Yangcheng Li <[email protected]>
Co-authored-by: Siyuan Li <[email protected]>
Co-authored-by: Concurrensee <[email protected]>
Co-authored-by: Chenguang Li <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Alex Brooks <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Shanshan Shen <[email protected]>
Co-authored-by: elijah <[email protected]>
Co-authored-by: Yikun Jiang <[email protected]>
Co-authored-by: Steve Luo <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Konrad Zawora <[email protected]>
Co-authored-by: TJian <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
Co-authored-by: maang-h <[email protected]>
Co-authored-by: Elfie Guo <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent db442cd commit b18e3ee

File tree

15 files changed: +1622 additions, −186 deletions


docs/source/models/supported_models.md

Lines changed: 8 additions & 1 deletion
@@ -693,9 +693,16 @@ See [this page](#generative-models) for more information on how to use generativ
   *
   * ✅︎
   * ✅︎
+- * `MiniCPMO`
+  * MiniCPM-O
+  * T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup>
+  * `openbmb/MiniCPM-o-2_6`, etc.
+  * ✅︎
+  * ✅︎
+  *
 - * `MiniCPMV`
   * MiniCPM-V
-  * T + I<sup>E+</sup>
+  * T + I<sup>E+</sup> + V<sup>E+</sup>
   * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
   * ✅︎
   * ✅︎
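(Reading note for this excerpt: in the supported-models table, T/I/V/A denote text, image, video, and audio inputs; per the table's legend elsewhere in supported_models.md, <sup>E</sup> marks a modality that accepts pre-computed embeddings and <sup>+</sup> one that accepts multiple items per prompt. The new row thus advertises text plus multi-item image, video, and audio inputs for MiniCPM-O.)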

examples/offline_inference/audio_language.py

Lines changed: 31 additions & 1 deletion
@@ -67,7 +67,37 @@ def run_qwen2_audio(question: str, audio_count: int):
     return llm, prompt, stop_token_ids


-model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio}
+def run_minicpmo(question: str, audio_count: int):
+    model_name = "openbmb/MiniCPM-o-2_6"
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    llm = LLM(model=model_name,
+              trust_remote_code=True,
+              max_model_len=4096,
+              max_num_seqs=5,
+              limit_mm_per_prompt={"audio": audio_count})
+
+    stop_tokens = ['<|im_end|>', '<|endoftext|>']
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
+    audio_placeholder = "(<audio>./</audio>)" * audio_count
+    audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"  # noqa: E501
+    messages = [{
+        'role': 'user',
+        'content': f'{audio_placeholder}\n{question}'
+    }]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True,
+                                           chat_template=audio_chat_template)
+    return llm, prompt, stop_token_ids
+
+
+model_example_map = {
+    "ultravox": run_ultravox,
+    "qwen2_audio": run_qwen2_audio,
+    "minicpmo": run_minicpmo
+}


 def main(args):
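As a usage sketch (not part of the commit): the tuple returned by `run_minicpmo` is consumed by the example's `main()` roughly as below. The audio file path and sampling parameters are illustrative assumptions; vLLM accepts audio inputs as `(waveform, sampling_rate)` tuples under `multi_modal_data`.

import librosa
from vllm import SamplingParams

# Assumes run_minicpmo from the diff above is in scope.
llm, prompt, stop_token_ids = run_minicpmo("What is said in this clip?",
                                           audio_count=1)

# "sample.wav" is a placeholder path; librosa returns (waveform, sample_rate).
audio, sr = librosa.load("sample.wav", sr=None)

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

outputs = llm.generate(
    {
        "prompt": prompt,
        # One (waveform, rate) tuple per audio placeholder in the prompt.
        "multi_modal_data": {"audio": [(audio, sr)]},
    },
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)

Note that the custom `audio_chat_template` opens the assistant turn with `<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>`, the speaker/TTS control tokens MiniCPM-o-2.6 expects.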

examples/offline_inference/vision_language.py

Lines changed: 28 additions & 5 deletions
@@ -265,8 +265,9 @@ def run_mantis(question: str, modality: str):


 # MiniCPM-V
-def run_minicpmv(question: str, modality: str):
-    assert modality == "image"
+def run_minicpmv_base(question: str, modality: str, model_name):
+    assert modality in ["image", "video"]
+    # If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa

     # 2.0
     # The official repo doesn't work yet, so we need to use a fork for now
@@ -277,7 +278,15 @@ def run_minicpmv(question: str, modality: str):
     # model_name = "openbmb/MiniCPM-Llama3-V-2_5"

     # 2.6
-    model_name = "openbmb/MiniCPM-V-2_6"
+    # model_name = "openbmb/MiniCPM-V-2_6"
+    # o2.6
+
+    # modality supports
+    # 2.0: image
+    # 2.5: image
+    # 2.6: image, video
+    # o2.6: image, video, audio
+    # model_name = "openbmb/MiniCPM-o-2_6"
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
     llm = LLM(
@@ -294,20 +303,33 @@
     # 2.5
     # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

-    # 2.6
+    # 2.6 / o2.6
     stop_tokens = ['<|im_end|>', '<|endoftext|>']
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

+    modality_placeholder = {
+        "image": "(<image>./</image>)",
+        "video": "(<video>./</video>)",
+    }
+
     messages = [{
         'role': 'user',
-        'content': f'(<image>./</image>)\n{question}'
+        'content': f'{modality_placeholder[modality]}\n{question}'
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
     return llm, prompt, stop_token_ids


+def run_minicpmo(question: str, modality: str):
+    return run_minicpmv_base(question, modality, "openbmb/MiniCPM-o-2_6")
+
+
+def run_minicpmv(question: str, modality: str):
+    return run_minicpmv_base(question, modality, "openbmb/MiniCPM-V-2_6")
+
+
 # LLama 3.2
 def run_mllama(question: str, modality: str):
     assert modality == "image"
@@ -523,6 +545,7 @@ def run_qwen2_vl(question: str, modality: str):
     "llava-next-video": run_llava_next_video,
     "llava-onevision": run_llava_onevision,
     "mantis": run_mantis,
+    "minicpmo": run_minicpmo,
     "minicpmv": run_minicpmv,
     "mllama": run_mllama,
     "molmo": run_molmo,

requirements-cpu.txt

Lines changed: 1 addition & 0 deletions
@@ -4,5 +4,6 @@
 # Dependencies for CPUs
 torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" and platform_system != "Darwin"
 torch==2.5.1; platform_machine == "aarch64" or platform_system == "Darwin"
+torchaudio; platform_machine != "ppc64le" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch
 torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
 datasets # for benchmark scripts

requirements-cuda.txt

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 ray[default] >= 2.9
 nvidia-ml-py >= 12.560.30 # for pynvml package
 torch == 2.5.1
+torchaudio==2.5.1
 # These must be updated alongside torch
 torchvision == 0.20.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1

requirements-test.in

Lines changed: 3 additions & 0 deletions
@@ -12,13 +12,16 @@ decord # required for video tests
 einops # required for MPT, qwen-vl and Mamba
 httpx
 librosa # required for audio tests
+vector_quantize_pytorch # required for minicpmo_26 test
+vocos # required for minicpmo_26 test
 peft
 pqdm
 ray[adag]==2.40.0
 sentence-transformers # required for embedding tests
 soundfile # required for audio tests
 timm # required for internvl test
 torch==2.5.1
+torchaudio==2.5.1
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
 mistral_common[opencv] >= 1.5.0 # required for pixtral test

requirements-test.txt

Lines changed: 35 additions & 2 deletions
@@ -106,9 +106,17 @@ dnspython==2.7.0
 docutils==0.16
     # via awscli
 einops==0.8.0
-    # via -r requirements-test.in
+    # via
+    #   -r requirements-test.in
+    #   encodec
+    #   vector-quantize-pytorch
+    #   vocos
+einx==0.3.0
+    # via vector-quantize-pytorch
 email-validator==2.2.0
     # via pydantic
+encodec==0.1.1
+    # via vocos
 evaluate==0.4.3
     # via lm-eval
 fastparquet==2024.11.0
@@ -125,6 +133,8 @@ filelock==3.16.1
     #   triton
 fonttools==4.54.1
     # via matplotlib
+frozendict==2.4.6
+    # via einx
 frozenlist==1.5.0
     # via
     #   aiohttp
@@ -159,6 +169,7 @@ huggingface-hub==0.26.2
     #   timm
     #   tokenizers
     #   transformers
+    #   vocos
 idna==3.10
     # via
     #   anyio
@@ -261,6 +272,8 @@ numpy==1.26.4
     #   cupy-cuda12x
     #   datasets
     #   decord
+    #   einx
+    #   encodec
     #   evaluate
     #   fastparquet
     #   genai-perf
@@ -283,6 +296,7 @@ numpy==1.26.4
     #   torchvision
     #   transformers
     #   tritonclient
+    #   vocos
 nvidia-cublas-cu12==12.4.5.8
     # via
     #   nvidia-cudnn-cu12
@@ -455,6 +469,7 @@ pyyaml==6.0.2
     #   responses
     #   timm
     #   transformers
+    #   vocos
 ray[adag]==2.40.0
     # via -r requirements-test.in
 redis==5.2.0
@@ -517,6 +532,7 @@ scipy==1.13.1
     #   scikit-learn
     #   sentence-transformers
     #   statsmodels
+    #   vocos
 sentence-transformers==3.2.1
     # via -r requirements-test.in
 sentencepiece==0.2.0
@@ -540,7 +556,9 @@ sqlitedict==2.1.0
 statsmodels==0.14.4
     # via genai-perf
 sympy==1.13.1
-    # via torch
+    # via
+    #   einx
+    #   torch
 tabledata==1.3.3
     # via pytablewriter
 tabulate==0.9.0
@@ -568,12 +586,21 @@ torch==2.5.1
     #   -r requirements-test.in
     #   accelerate
     #   bitsandbytes
+    #   encodec
     #   lm-eval
     #   peft
     #   sentence-transformers
     #   tensorizer
     #   timm
+    #   torchaudio
     #   torchvision
+    #   vector-quantize-pytorch
+    #   vocos
+torchaudio==2.5.1
+    # via
+    #   -r requirements-test.in
+    #   encodec
+    #   vocos
 torchvision==0.20.1
     # via timm
 tqdm==4.66.6
@@ -584,6 +611,7 @@ tqdm==4.66.6
     #   lm-eval
     #   nltk
     #   peft
+    #   pqdm
     #   sentence-transformers
     #   tqdm-multiprocess
     #   transformers
@@ -615,6 +643,7 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   librosa
     #   mistral-common
+    #   pqdm
     #   pydantic
     #   pydantic-core
     #   torch
@@ -626,6 +655,10 @@ urllib3==2.2.3
     #   requests
     #   responses
     #   tritonclient
+vector-quantize-pytorch==1.21.2
+    # via -r requirements-test.in
+vocos==0.1.0
+    # via -r requirements-test.in
 word2number==1.1
     # via lm-eval
 xxhash==3.5.0
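(Note: requirements-test.txt is the pip-compile lockfile generated from requirements-test.in, so the encodec, einx, and frozendict entries above are transitive dependencies that follow mechanically from adding vector_quantize_pytorch, vocos, and torchaudio at the top level.)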

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 14 additions & 0 deletions
@@ -350,6 +350,20 @@
         postprocess_inputs=model_utils.wrap_inputs_post_processor,
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
     ),
+    "minicpmo_26": VLMTestInfo(
+        models=["openbmb/MiniCPM-o-2_6"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
+        postprocess_inputs=model_utils.ignore_inputs_post_processor(
+            "image_sizes"
+        ),
+        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
+        patch_hf_runner=model_utils.minicpmo_patch_hf_runner
+    ),
     "minicpmv_26": VLMTestInfo(
         models=["openbmb/MiniCPM-V-2_6"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),

tests/models/decoder_only/vision_language/vlm_utils/model_utils.py

Lines changed: 11 additions & 0 deletions
@@ -497,6 +497,17 @@ def _generate(self, *args, **kwargs):
     return hf_model


+def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    orig_generate = hf_model.model.generate
+
+    def _generate(self, *args, **kwargs):
+        return orig_generate(*args, decode_text=False, **kwargs)
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
+
+    return hf_model
+
+
 def _generate_greedy_logprobs_limit(
     self,
     prompts: List[str],
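The patch above uses a standard Python monkey-patching idiom: capture the original bound method, then re-bind a wrapper onto the one instance with `types.MethodType`, so only that runner's `generate` gains the forced `decode_text=False` keyword while the class is untouched. A self-contained sketch of the same idiom, with made-up names for illustration:

import types


class Greeter:
    def greet(self, name):
        return f"hello {name}"


g = Greeter()
orig_greet = g.greet  # bound method; already carries `g`


def _greet(self, *args, **kwargs):
    # Force extra behavior, mirroring how the runner forces decode_text=False.
    return orig_greet(*args, **kwargs).upper()


# Bind the wrapper to this one instance only; other Greeters are unaffected.
g.greet = types.MethodType(_greet, g)

assert g.greet("world") == "HELLO WORLD"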

tests/models/multimodal/processing/test_common.py

Lines changed: 2 additions & 0 deletions
@@ -152,6 +152,8 @@ def _test_processing_correctness(
         "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
         "TIGER-Lab/Mantis-8B-siglip-llama3",
         "mistral-community/pixtral-12b",
+        "openbmb/MiniCPM-o-2_6",
+        "openbmb/MiniCPM-V-2_6",
         "Qwen/Qwen-VL-Chat",
         "Qwen/Qwen2-VL-2B-Instruct",
         "Qwen/Qwen2-Audio-7B-Instruct",
