Commit a006d17

fix and speed up tests
Signed-off-by: Nick Hill <[email protected]>
1 parent e105532 commit a006d17

File tree

tests/v1/engine/test_llm_engine.py
tests/v1/sample/test_logprobs.py

2 files changed: +67 -93 lines changed

tests/v1/engine/test_llm_engine.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
     prompt_logprobs enabled, which is incompatible."""
 
     monkeypatch.setenv("VLLM_USE_V1", "1")
+    # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
     with pytest.raises(ValueError) as excinfo:
         LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
             "Hello, my name is",

tests/v1/sample/test_logprobs.py

Lines changed: 65 additions & 93 deletions
@@ -14,7 +14,34 @@
 
 from ...conftest import VllmRunner
 
-MODELS = ["meta-llama/Llama-3.2-1B"]
+MODEL = "meta-llama/Llama-3.2-1B"
+DTYPE = "half"
+
+
+@pytest.fixture(scope="module")
+def vllm_model(vllm_runner):
+    with vllm_runner(
+            MODEL,
+            dtype=DTYPE,
+            max_logprobs=7,
+            # Very small number of batched tokens to ensure
+            # that we test chunking.
+            max_num_batched_tokens=16,
+            max_num_seqs=16,
+            max_model_len=128,
+            enforce_eager=True,
+            # TODO: enable this once we support it for
+            # prompt logprobs.
+            enable_prefix_caching=False,
+            gpu_memory_utilization=0.5,
+    ) as vllm_model:
+        yield vllm_model
+
+
+@pytest.fixture(scope="module")
+def hf_model(hf_runner):
+    with hf_runner(MODEL, dtype=DTYPE) as hf_model:
+        yield hf_model
 
 
 def _repeat_logprob_config(
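
Most of the speedup in this file comes from the two module-scoped fixtures above: the vLLM and HF models are now constructed once per test module and shared by every test, instead of being rebuilt inside each test and each parametrization. A generic sketch of that pattern, not taken from the commit; load_expensive_model() is a hypothetical stand-in for the vllm_runner / hf_runner context managers that the real fixtures wrap:

# Generic sketch of the module-scoped fixture pattern (not from the commit).
# load_expensive_model() is a hypothetical stand-in for the real runners.
import pytest


def load_expensive_model():
    # Placeholder for loading model weights onto the GPU.
    return object()


@pytest.fixture(scope="module")
def shared_model():
    model = load_expensive_model()  # runs once per test module
    yield model                     # every test in the module gets this object
    # Teardown after the yield runs once, after the last test in the module.


def test_uses_model_first(shared_model):
    assert shared_model is not None


def test_reuses_same_model(shared_model):
    # Same instance as the previous test; no second model load.
    assert shared_model is not None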
@@ -66,30 +93,23 @@ def _repeat_logprob_config(
 
 
 def _test_case_get_logprobs_and_prompt_logprobs(
-    hf_runner,
-    vllm_runner,
-    model: str,
-    dtype: str,
+    hf_model,
+    vllm_model,
     batch_logprobs_composition: str,
-    max_num_batched_tokens: int,
     temperature: float,
     example_prompts,
 ) -> None:
     test_prompts = example_prompts
 
-    max_num_seqs = 16
-    max_model_len = 128
-
     max_tokens = 5
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(
-            test_prompts,
-            max_tokens=max_tokens,
-        )
-        hf_logprobs = hf_model.generate_greedy_logprobs(
-            test_prompts,
-            max_tokens=max_tokens,
-        )
+    hf_outputs = hf_model.generate_greedy(
+        test_prompts,
+        max_tokens=max_tokens,
+    )
+    hf_logprobs = hf_model.generate_greedy_logprobs(
+        test_prompts,
+        max_tokens=max_tokens,
+    )
 
     # Batch has mixed sample params
     # (different logprobs/prompt logprobs combos)
@@ -108,20 +128,8 @@ def _test_case_get_logprobs_and_prompt_logprobs(
         for num_lp, num_plp in logprob_prompt_logprob_list
     ]
 
-    with vllm_runner(
-            model,
-            dtype=dtype,
-            max_logprobs=7,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_seqs=max_num_seqs,
-            max_model_len=max_model_len,
-            enforce_eager=True,
-            # TODO: enable this once we support it for
-            # prompt logprobs.
-            enable_prefix_caching=False,
-    ) as vllm_model:
-        vllm_results = vllm_model.model.generate(
-            test_prompts, sampling_params=vllm_sampling_params)
+    vllm_results = vllm_model.model.generate(
+        test_prompts, sampling_params=vllm_sampling_params)
 
     for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
             vllm_results, hf_logprobs, hf_outputs,
@@ -260,21 +268,14 @@ def _test_case_get_logprobs_and_prompt_logprobs(
         assert vllm_result.prompt_logprobs is None
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype",
-                         ["half"])  # needed for comparing logprobs with HF
-# Include a very small max_num_batched_tokens to ensure we test chunking
-@pytest.mark.parametrize("max_num_batched_tokens", [16, 256])
+#@pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("batch_logprobs_composition",
                          ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"])
 @pytest.mark.parametrize("temperature", [0.0, 2.0])
 def test_get_logprobs_and_prompt_logprobs(
-    hf_runner,
-    vllm_runner,
-    model: str,
-    dtype: str,
+    hf_model,
+    vllm_model,
     batch_logprobs_composition: str,
-    max_num_batched_tokens: int,
     temperature: float,
     example_prompts,
 ) -> None:
@@ -293,22 +294,16 @@ def test_get_logprobs_and_prompt_logprobs(
     requests in the batch under test.
 
     Args:
-      hf_runner
-      vllm_runner
-      model
-      dtype
+      hf_model
+      vllm_model
       batch_logprobs_composition: logprobs configuration for test batch
-      max_num_batched_tokens: token budget for scheduling
       example_prompts
       monkeypatch
     """
     _test_case_get_logprobs_and_prompt_logprobs(
-        hf_runner=hf_runner,
-        vllm_runner=vllm_runner,
-        model=model,
-        dtype=dtype,
+        hf_model=hf_model,
+        vllm_model=vllm_model,
         batch_logprobs_composition=batch_logprobs_composition,
-        max_num_batched_tokens=max_num_batched_tokens,
         temperature=temperature,
         example_prompts=example_prompts)

@@ -325,7 +320,8 @@ def test_max_logprobs(monkeypatch):
 
     runner = VllmRunner("facebook/opt-125m",
                         max_logprobs=1,
-                        enable_prefix_caching=False)
+                        enable_prefix_caching=False,
+                        max_model_len=256)
     vllm_sampling_params = SamplingParams(logprobs=1)
     # should pass
     runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
@@ -335,35 +331,23 @@ def test_max_logprobs(monkeypatch):
         runner.generate(["Hello world"], sampling_params=bad_sampling_params)
 
 
-@pytest.mark.parametrize("model", MODELS)
-def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch):
+def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
     """Engine should return `logprobs` and `prompt_logprobs` as `None`
 
     Args:
-      vllm_runner: vLLM engine runner fixture
-      model: model name
+      vllm_model: vLLM model fixture
       example_prompts: list of example prompts (test fixture)
       monkeypatch: supports editing env vars and rolling back changes
                    after the test
     """
-    override_backend_env_variable(monkeypatch, "FLASH_ATTN")
-
-    max_num_seqs = 256
-    max_num_batched_tokens = None
     max_tokens = 5
 
-    with vllm_runner(
-            model,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_seqs=max_num_seqs,
-            enable_prefix_caching=False,
-    ) as vllm_model:
-        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
-                                                       logprobs=None,
-                                                       prompt_logprobs=None,
-                                                       temperature=0.0)
-        results_logprobs_none = vllm_model.model.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_none)
+    sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
+                                                   logprobs=None,
+                                                   prompt_logprobs=None,
+                                                   temperature=0.0)
+    results_logprobs_none = vllm_model.model.generate(
+        example_prompts, sampling_params=sampling_params_logprobs_none)
 
     for i in range(len(results_logprobs_none)):
         # Check sample logprobs are None
@@ -373,35 +357,23 @@ def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch):
         assert results_logprobs_none[i].prompt_logprobs is None
 
 
-@pytest.mark.parametrize("model", MODELS)
-def test_zero_logprobs(vllm_runner, model, example_prompts, monkeypatch):
+def test_zero_logprobs(vllm_model, example_prompts, monkeypatch):
    """Engine should return sampled token and prompt token logprobs
 
     Args:
-      vllm_runner: vLLM engine runner fixture
-      model: model name
+      vllm_model: vLLM model fixture
       example_prompts: list of example prompts (test fixture)
       monkeypatch: supports editing env vars and rolling back changes
                    after the test
     """
-    override_backend_env_variable(monkeypatch, "FLASH_ATTN")
-
-    max_num_seqs = 256
-    max_num_batched_tokens = None
     max_tokens = 5
 
-    with vllm_runner(
-            model,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_seqs=max_num_seqs,
-            enable_prefix_caching=False,
-    ) as vllm_model:
-        sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
-                                                       logprobs=0,
-                                                       prompt_logprobs=0,
-                                                       temperature=0.0)
-        results_logprobs_zero = vllm_model.model.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_zero)
+    sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
+                                                   logprobs=0,
+                                                   prompt_logprobs=0,
+                                                   temperature=0.0)
+    results_logprobs_zero = vllm_model.model.generate(
+        example_prompts, sampling_params=sampling_params_logprobs_zero)
 
     for i in range(len(results_logprobs_zero)):
         # Check that there is one sample logprob dict for each

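The two refactored tests at the end assert the difference between logprobs=None and logprobs=0: with None, no logprob data is attached to the results; with 0, each generated token (and, with prompt_logprobs=0, each prompt token) carries a logprob entry for the token itself only, with no extra top-k alternatives. A hedged standalone sketch of that distinction, not part of the commit; the model name and max_model_len are illustrative, and running it requires an environment where the model can load:

# Sketch (not part of the commit): contrast logprobs=None with logprobs=0.
# Model name and max_model_len are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_model_len=256)

none_params = SamplingParams(max_tokens=5, logprobs=None,
                             prompt_logprobs=None, temperature=0.0)
zero_params = SamplingParams(max_tokens=5, logprobs=0,
                             prompt_logprobs=0, temperature=0.0)

none_out = llm.generate(["Hello world"], none_params)[0]
zero_out = llm.generate(["Hello world"], zero_params)[0]

# logprobs=None / prompt_logprobs=None: nothing attached to the results.
assert none_out.outputs[0].logprobs is None
assert none_out.prompt_logprobs is None

# logprobs=0: one logprob dict per generated token, containing at least
# (and normally only) the sampled token itself.
assert all(len(lp) >= 1 for lp in zero_out.outputs[0].logprobs)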
0 commit comments