@@ -9,9 +9,15 @@
from ...utils import check_outputs_equal

# This test is for the hybrid models
- MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
+ MODELS = [
+     "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
+     "pfnet/plamo-2-1b"
+ ]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
+ # Note: Running Plamo2 with the transformers implementation requires
+ # installing the causal-conv1d package, which is not listed as a test
+ # dependency because it is not compatible with pip-compile.


@pytest.mark.parametrize("model", MODELS)
@@ -25,21 +31,11 @@ def test_models(
    dtype: str,
    max_tokens: int,
) -> None:
-
    # numeric error produces different generation
    if "Bamba" in model:
        example_prompts.pop(3)

-     model_kwargs = {
-         "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-         # don't use them
-     }
-     if "Zamba2" in model:
-         # Zamba2 HF implementation automatically checks if mamba kernels are
-         # installed
-         model_kwargs = {}
-
-     with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+     with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model, dtype=dtype) as vllm_model:
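For orientation, a minimal sketch of the HF-vs-vLLM greedy comparison that test_models performs, assuming hf_runner, vllm_runner, and check_outputs_equal behave as the surrounding diff suggests (the helper name run_greedy_comparison is made up here):

# Sketch only: fixture and helper signatures are assumed from the diff above.
def run_greedy_comparison(hf_runner, vllm_runner, example_prompts,
                          model: str, dtype: str, max_tokens: int) -> None:
    # No model_kwargs override is needed anymore: the HF implementations
    # check for mamba/causal-conv1d kernels themselves.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    # vLLM run with the same prompts and greedy decoding.
    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    # Exact-match comparison between the two backends.
    check_outputs_equal(outputs_0_lst=hf_outputs,
                        outputs_1_lst=vllm_outputs,
                        name_0="hf",
                        name_1="vllm")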
@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
    # correctly for n > 1 decoding steps inside a
    # chunked prefill forward pass (where we have both prefills
    # and decoding together)
+
+     if 'plamo-2' in model:
+         dtype = "float"  # use a different dtype for plamo
+
    sampling_params = SamplingParams(n=3,
                                     temperature=1,
                                     seed=0,
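As a rough standalone sketch of what this test exercises, here is parallel sampling combined with chunked prefill through the public LLM API; the model name and chunking limits below are illustrative assumptions, not values from the test:

from vllm import LLM, SamplingParams

# n=3 samples per prompt keeps several sequences decoding while later prompts
# are still being prefilled in small chunks, i.e. mixed prefill + decode
# batches in a single forward pass.
sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=20)
llm = LLM(model="ai21labs/Jamba-tiny-dev",
          enable_chunked_prefill=True,
          max_num_batched_tokens=16,  # small token budget forces chunking
          max_num_seqs=4)
outputs = llm.generate(["Hello, my name is"] * 4, sampling_params)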
@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
        example_prompts.pop(3)
        example_prompts.pop(2)
        dtype = "half"  # use a different dtype for Bamba
+
    elif "Zamba2" in model:
        example_prompts.pop(7)
        dtype = "half"
+     elif "plamo-2-1b" in model:
+         example_prompts.pop(7)

-     model_kwargs = {
-         "use_mamba_kernels": False,  # mamba kernels are not installed so HF
-         # don't use them
-     }
-     if "Zamba2" in model:
-         # Zamba2 HF implementation automatically checks if mamba kernels are
-         # installed
-         model_kwargs = {}
-
-     with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
+     with hf_runner(model, dtype=dtype) as hf_model:
        non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model,
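A condensed sketch of the chunked-vs-non-chunked check that follows this hunk, with the runner fixtures assumed from the diff and the chunking parameters assumed for illustration:

# Sketch only: dtype has already been adjusted per model above.
with hf_runner(model, dtype=dtype) as hf_model:
    non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
                 dtype=dtype,
                 enable_chunked_prefill=True,
                 max_num_batched_tokens=5,  # tiny budget splits each prefill
                 max_num_seqs=2) as vllm_model:
    chunked = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(outputs_0_lst=chunked,
                    outputs_1_lst=non_chunked,
                    name_0="chunked",
                    name_1="non_chunked")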
@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
    # This test is for verifying that mamba cache is padded to CG captured
    # batch size. If it's not, a torch RuntimeError will be raised because
    # tensor dimensions aren't compatible
-     vllm_config = EngineArgs(model=model).create_engine_config()
+     vllm_config = EngineArgs(model=model,
+                              trust_remote_code=True).create_engine_config()
    while len(example_prompts) == vllm_config.pad_for_cudagraph(
            len(example_prompts)):
        example_prompts.append(example_prompts[0])
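For context, a small sketch of the padding behaviour this test relies on, assuming pad_for_cudagraph rounds a batch size up to the nearest CUDA-graph-captured size, as the loop above implies:

from vllm.engine.arg_utils import EngineArgs

# trust_remote_code=True is required because plamo-2 ships its modeling code
# on the Hugging Face Hub rather than inside transformers.
vllm_config = EngineArgs(model="pfnet/plamo-2-1b",
                         trust_remote_code=True).create_engine_config()

batch_size = 5
padded = vllm_config.pad_for_cudagraph(batch_size)
# Whenever batch_size != padded, the mamba cache must be padded up to `padded`
# so its shape matches the batch size the CUDA graph was captured with.
print(batch_size, padded)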