Skip to content

Commit accb858

Browse files
DarkLight1337mzusman
authored andcommitted
[Doc] Basic guide for writing unit tests for new models (vllm-project#11951)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent d1cb412 commit accb858

File tree

6 files changed

+81
-3
lines changed

6 files changed

+81
-3
lines changed

docs/source/contributing/model/basic.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
(new-model-basic)=
22

3-
# Basic Implementation
3+
# Implementing a Basic Model
44

55
This guide walks you through the steps to implement a basic vLLM model.
66

docs/source/contributing/model/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ This section provides more information on how to integrate a [PyTorch](https://p
1010
1111
basic
1212
registration
13+
tests
1314
multimodal
1415
```
1516

docs/source/contributing/model/registration.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
(new-model-registration)=
22

3-
# Model Registration
3+
# Registering a Model to vLLM
44

55
vLLM relies on a model registry to determine how to run each model.
66
A list of pre-registered architectures can be found [here](#supported-models).
@@ -15,7 +15,6 @@ This gives you the ability to modify the codebase and test your model.
1515

1616
After you have implemented your model (see [tutorial](#new-model-basic)), put it into the <gh-dir:vllm/model_executor/models> directory.
1717
Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM.
18-
You should also include an example HuggingFace repository for this model in <gh-file:tests/models/registry.py> to run the unit tests.
1918
Finally, update our [list of supported models](#supported-models) to promote your model!
2019

2120
```{important}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
(new-model-tests)=
2+
3+
# Writing Unit Tests
4+
5+
This page explains how to write unit tests to verify the implementation of your model.
6+
7+
## Required Tests
8+
9+
These tests are necessary to get your PR merged into vLLM library.
10+
Without them, the CI for your PR will fail.
11+
12+
### Model loading
13+
14+
Include an example HuggingFace repository for your model in <gh-file:tests/models/registry.py>.
15+
This enables a unit test that loads dummy weights to ensure that the model can be initialized in vLLM.
16+
17+
```{important}
18+
The list of models in each section should be maintained in alphabetical order.
19+
```
20+
21+
```{tip}
22+
If your model requires a development version of HF Transformers, you can set
23+
`min_transformers_version` to skip the test in CI until the model is released.
24+
```
25+
26+
## Optional Tests
27+
28+
These tests are optional to get your PR merged into vLLM library.
29+
Passing these tests provides more confidence that your implementation is correct, and helps avoid future regressions.
30+
31+
### Model correctness
32+
33+
These tests compare the model outputs of vLLM against [HF Transformers](https://github.com/huggingface/transformers). You can add new tests under the subdirectories of <gh-dir:tests/models>.
34+
35+
#### Generative models
36+
37+
For [generative models](#generative-models), there are two levels of correctness tests, as defined in <gh-file:tests/models/utils.py>:
38+
39+
- Exact correctness (`check_outputs_equal`): The text outputted by vLLM should exactly match the text outputted by HF.
40+
- Logprobs similarity (`check_logprobs_close`): The logprobs outputted by vLLM should be in the top-k logprobs outputted by HF, and vice versa.
41+
42+
#### Pooling models
43+
44+
For [pooling models](#pooling-models), we simply check the cosine similarity, as defined in <gh-file:tests/models/embedding/utils.py>.
45+
46+
(mm-processing-tests)=
47+
48+
### Multi-modal processing
49+
50+
#### Common tests
51+
52+
Adding your model to <gh-file:tests/models/multimodal/processing/test_common.py> verifies that the following input combinations result in the same outputs:
53+
54+
- Text + multi-modal data
55+
- Tokens + multi-modal data
56+
- Text + cached multi-modal data
57+
- Tokens + cached multi-modal data
58+
59+
#### Model-specific tests
60+
61+
You can add a new file under <gh-dir:tests/models/multimodal/processing> to run tests that only apply to your model.
62+
63+
For example, if the HF processor for your model accepts user-specified keyword arguments, you can verify that the keyword arguments are being applied correctly, such as in <gh-file:tests/models/multimodal/processing/test_phi3v.py>.

tests/models/registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ class _HfExamplesInfo:
2222
for speculative decoding.
2323
"""
2424

25+
min_transformers_version: Optional[str] = None
26+
"""
27+
The minimum version of HF Transformers that is required to run this model.
28+
"""
29+
2530
is_available_online: bool = True
2631
"""
2732
Set this to ``False`` if the name of this architecture no longer exists on

tests/models/test_initialization.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from unittest.mock import patch
22

33
import pytest
4+
from packaging.version import Version
45
from transformers import PretrainedConfig
6+
from transformers import __version__ as TRANSFORMERS_VERSION
57

68
from vllm import LLM
79

@@ -13,6 +15,14 @@ def test_can_initialize(model_arch):
1315
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
1416
if not model_info.is_available_online:
1517
pytest.skip("Model is not available online")
18+
if model_info.min_transformers_version is not None:
19+
current_version = TRANSFORMERS_VERSION
20+
required_version = model_info.min_transformers_version
21+
if Version(current_version) < Version(required_version):
22+
pytest.skip(
23+
f"You have `transformers=={current_version}` installed, but "
24+
f"`transformers>={required_version}` is required to run this "
25+
"model")
1626

1727
# Avoid OOM
1828
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:

0 commit comments

Comments
 (0)