from ...conftest import VllmRunner

- MODELS = ["meta-llama/Llama-3.2-1B"]
+ MODEL = "meta-llama/Llama-3.2-1B"
+ DTYPE = "half"
+
+
+ @pytest.fixture(scope="module")
+ def vllm_model(vllm_runner):
+     with vllm_runner(
+             MODEL,
+             dtype=DTYPE,
+             max_logprobs=7,
+             # Very small number of batched tokens to ensure
+             # that we test chunking.
+             max_num_batched_tokens=16,
+             max_num_seqs=16,
+             max_model_len=128,
+             enforce_eager=True,
+             # TODO: enable this once we support it for
+             # prompt logprobs.
+             enable_prefix_caching=False,
+             gpu_memory_utilization=0.5,
+     ) as vllm_model:
+         yield vllm_model
+
+
+ @pytest.fixture(scope="module")
+ def hf_model(hf_runner):
+     with hf_runner(MODEL, dtype=DTYPE) as hf_model:
+         yield hf_model


def _repeat_logprob_config(
@@ -66,30 +93,23 @@ def _repeat_logprob_config(


def _test_case_get_logprobs_and_prompt_logprobs(
-     hf_runner,
-     vllm_runner,
-     model: str,
-     dtype: str,
+     hf_model,
+     vllm_model,
    batch_logprobs_composition: str,
-     max_num_batched_tokens: int,
    temperature: float,
    example_prompts,
) -> None:
    test_prompts = example_prompts

-     max_num_seqs = 16
-     max_model_len = 128
-
    max_tokens = 5
-     with hf_runner(model, dtype=dtype) as hf_model:
-         hf_outputs = hf_model.generate_greedy(
-             test_prompts,
-             max_tokens=max_tokens,
-         )
-         hf_logprobs = hf_model.generate_greedy_logprobs(
-             test_prompts,
-             max_tokens=max_tokens,
-         )
+     hf_outputs = hf_model.generate_greedy(
+         test_prompts,
+         max_tokens=max_tokens,
+     )
+     hf_logprobs = hf_model.generate_greedy_logprobs(
+         test_prompts,
+         max_tokens=max_tokens,
+     )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
@@ -108,20 +128,8 @@ def _test_case_get_logprobs_and_prompt_logprobs(
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]

-     with vllm_runner(
-             model,
-             dtype=dtype,
-             max_logprobs=7,
-             max_num_batched_tokens=max_num_batched_tokens,
-             max_num_seqs=max_num_seqs,
-             max_model_len=max_model_len,
-             enforce_eager=True,
-             # TODO: enable this once we support it for
-             # prompt logprobs.
-             enable_prefix_caching=False,
-     ) as vllm_model:
-         vllm_results = vllm_model.model.generate(
-             test_prompts, sampling_params=vllm_sampling_params)
+     vllm_results = vllm_model.model.generate(
+         test_prompts, sampling_params=vllm_sampling_params)

    for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
            vllm_results, hf_logprobs, hf_outputs,
@@ -260,21 +268,14 @@ def _test_case_get_logprobs_and_prompt_logprobs(
            assert vllm_result.prompt_logprobs is None


- @pytest.mark.parametrize("model", MODELS)
- @pytest.mark.parametrize("dtype",
-                          ["half"])  # needed for comparing logprobs with HF
- # Include a very small max_num_batched_tokens to ensure we test chunking
- @pytest.mark.parametrize("max_num_batched_tokens", [16, 256])
+ # @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("batch_logprobs_composition",
                         ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"])
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
-     hf_runner,
-     vllm_runner,
-     model: str,
-     dtype: str,
+     hf_model,
+     vllm_model,
    batch_logprobs_composition: str,
-     max_num_batched_tokens: int,
    temperature: float,
    example_prompts,
) -> None:
@@ -293,22 +294,16 @@ def test_get_logprobs_and_prompt_logprobs(
    requests in the batch under test.

    Args:
-       hf_runner
-       vllm_runner
-       model
-       dtype
+       hf_model
+       vllm_model
      batch_logprobs_composition: logprobs configuration for test batch
-       max_num_batched_tokens: token budget for scheduling
      example_prompts
      monkeypatch
    """
    _test_case_get_logprobs_and_prompt_logprobs(
-         hf_runner=hf_runner,
-         vllm_runner=vllm_runner,
-         model=model,
-         dtype=dtype,
+         hf_model=hf_model,
+         vllm_model=vllm_model,
        batch_logprobs_composition=batch_logprobs_composition,
-         max_num_batched_tokens=max_num_batched_tokens,
        temperature=temperature,
        example_prompts=example_prompts)

@@ -325,7 +320,8 @@ def test_max_logprobs(monkeypatch):

    runner = VllmRunner("facebook/opt-125m",
                        max_logprobs=1,
-                         enable_prefix_caching=False)
+                         enable_prefix_caching=False,
+                         max_model_len=256)
    vllm_sampling_params = SamplingParams(logprobs=1)
    # should pass
    runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
@@ -335,35 +331,23 @@ def test_max_logprobs(monkeypatch):
        runner.generate(["Hello world"], sampling_params=bad_sampling_params)


- @pytest.mark.parametrize("model", MODELS)
- def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch):
+ def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`

    Args:
-       vllm_runner: vLLM engine runner fixture
-       model: model name
+       vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: supports editing env vars and rolling back changes
                   after the test
    """
-     override_backend_env_variable(monkeypatch, "FLASH_ATTN")
-
-     max_num_seqs = 256
-     max_num_batched_tokens = None
    max_tokens = 5

-     with vllm_runner(
-             model,
-             max_num_batched_tokens=max_num_batched_tokens,
-             max_num_seqs=max_num_seqs,
-             enable_prefix_caching=False,
-     ) as vllm_model:
-         sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
-                                                        logprobs=None,
-                                                        prompt_logprobs=None,
-                                                        temperature=0.0)
-         results_logprobs_none = vllm_model.model.generate(
-             example_prompts, sampling_params=sampling_params_logprobs_none)
+     sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
+                                                    logprobs=None,
+                                                    prompt_logprobs=None,
+                                                    temperature=0.0)
+     results_logprobs_none = vllm_model.model.generate(
+         example_prompts, sampling_params=sampling_params_logprobs_none)

    for i in range(len(results_logprobs_none)):
        # Check sample logprobs are None
@@ -373,35 +357,23 @@ def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch):
        assert results_logprobs_none[i].prompt_logprobs is None


- @pytest.mark.parametrize("model", MODELS)
- def test_zero_logprobs(vllm_runner, model, example_prompts, monkeypatch):
+ def test_zero_logprobs(vllm_model, example_prompts, monkeypatch):
    """Engine should return sampled token and prompt token logprobs

    Args:
-       vllm_runner: vLLM engine runner fixture
-       model: model name
+       vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: supports editing env vars and rolling back changes
                   after the test
    """
-     override_backend_env_variable(monkeypatch, "FLASH_ATTN")
-
-     max_num_seqs = 256
-     max_num_batched_tokens = None
    max_tokens = 5

-     with vllm_runner(
-             model,
-             max_num_batched_tokens=max_num_batched_tokens,
-             max_num_seqs=max_num_seqs,
-             enable_prefix_caching=False,
-     ) as vllm_model:
-         sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
-                                                        logprobs=0,
-                                                        prompt_logprobs=0,
-                                                        temperature=0.0)
-         results_logprobs_zero = vllm_model.model.generate(
-             example_prompts, sampling_params=sampling_params_logprobs_zero)
+     sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
+                                                    logprobs=0,
+                                                    prompt_logprobs=0,
+                                                    temperature=0.0)
+     results_logprobs_zero = vllm_model.model.generate(
+         example_prompts, sampling_params=sampling_params_logprobs_zero)

    for i in range(len(results_logprobs_zero)):
        # Check that there is one sample logprob dict for each
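
The refactor above replaces per-test construction of the HF and vLLM runners with module-scoped pytest fixtures: a scope="module" fixture is built once when the first test in the module requests it and torn down after the module's last test, so the expensive engine setup is shared across every parametrized case. A minimal, self-contained sketch of that pattern follows; FakeEngine and the test names are illustrative stand-ins, not part of this PR.

# test_fixture_sketch.py -- illustrative only; FakeEngine stands in for the
# expensive vllm_runner/hf_runner construction shown in the diff above.
import pytest


class FakeEngine:
    """Stand-in for an expensive-to-build model runner."""

    instances = 0  # counts how many engines were actually constructed

    def __init__(self, model: str):
        FakeEngine.instances += 1
        self.model = model

    def generate(self, prompt: str) -> str:
        return f"{self.model}: {prompt}"


@pytest.fixture(scope="module")
def engine():
    # Built once per test module, torn down after the module's last test.
    eng = FakeEngine("dummy-model")
    yield eng
    # Teardown (e.g. releasing GPU memory) would go here.


@pytest.mark.parametrize("prompt", ["a", "b", "c"])
def test_generate(engine, prompt):
    # Every parametrized case reuses the same module-scoped engine.
    assert engine.generate(prompt).endswith(prompt)


def test_single_construction(engine):
    # The engine was constructed exactly once for the whole module.
    assert FakeEngine.instances == 1

Run with "pytest -q": the counter shows the engine is constructed exactly once even though four tests consume the fixture, which is why settings such as gpu_memory_utilization=0.5 and the small batching limits are set once in the fixture rather than per test.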