
Commit b3d7e3c

Authored by charlotte12l, 22quinn, and DarkLight1337
[Sampler] Support returning all prompt logprobs (#23868)
Signed-off-by: Xingyu Liu <[email protected]>
Co-authored-by: 22quinn <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 6784131 commit b3d7e3c

4 files changed, +38 -18 lines
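For orientation, a minimal usage sketch of what this commit enables (the model name and prompt below are illustrative placeholders, not taken from the diff): passing logprobs=-1 and prompt_logprobs=-1 asks the engine for log probabilities over the entire vocabulary for both sampled tokens and prompt tokens.

# Minimal sketch; the model name and prompt are illustrative placeholders,
# not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_model_len=256)
params = SamplingParams(max_tokens=5, logprobs=-1, prompt_logprobs=-1)

outputs = llm.generate(["Hello, my name is"], sampling_params=params)
for output in outputs:
    # Each per-token entry should now contain vocab_size candidate logprobs.
    sample_logprobs = output.outputs[0].logprobs
    # The first prompt token has no preceding context, so its entry is None.
    prompt_logprobs = output.prompt_logprobs
    print(len(sample_logprobs[0]), prompt_logprobs[0])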

tests/v1/sample/test_logprobs.py

Lines changed: 10 additions & 2 deletions
@@ -430,7 +430,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
 
 
 def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
-    """Engine should return all vocabulary logprobs
+    """Engine should return all vocabulary logprobs and prompt logprobs
 
     Args:
       example_prompts: list of example prompts (test fixture)
@@ -444,16 +444,24 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
                         # 2 other llms alive during whole session
                         gpu_memory_utilization=0.15,
                         max_model_len=256)
+
     sampling_params_logprobs_all = SamplingParams(max_tokens=5,
-                                                  logprobs=-1)
+                                                  logprobs=-1,
+                                                  prompt_logprobs=-1)
     results_logprobs_all = runner.llm.generate(
         example_prompts, sampling_params=sampling_params_logprobs_all)
     vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
+
     for i in range(len(results_logprobs_all)):
         logprobs = results_logprobs_all[i].outputs[0].logprobs
+        prompt_logprobs = results_logprobs_all[i].prompt_logprobs
         assert logprobs is not None
         for logprob in logprobs:
             assert len(logprob) == vocab_size
+        assert prompt_logprobs is not None
+        assert prompt_logprobs[0] is None
+        for prompt_logprob in prompt_logprobs[1:]:
+            assert len(prompt_logprob) == vocab_size
 
 
 @pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))

vllm/sampling_params.py

Lines changed: 7 additions & 4 deletions
@@ -165,7 +165,8 @@ class SamplingParams(
     the sampled token, so there may be up to `logprobs+1` elements in the
     response. When set to -1, return all `vocab_size` log probabilities."""
     prompt_logprobs: Optional[int] = None
-    """Number of log probabilities to return per prompt token."""
+    """Number of log probabilities to return per prompt token.
+    When set to -1, return all `vocab_size` log probabilities."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
@@ -409,9 +410,11 @@ def _verify_args(self) -> None:
                 and self.logprobs < 0):
             raise ValueError(
                 f"logprobs must be non-negative or -1, got {self.logprobs}.")
-        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
-            raise ValueError(f"prompt_logprobs must be non-negative, got "
-                             f"{self.prompt_logprobs}.")
+        if (self.prompt_logprobs is not None and self.prompt_logprobs != -1
+                and self.prompt_logprobs < 0):
+            raise ValueError(
+                f"prompt_logprobs must be non-negative or -1, got "
+                f"{self.prompt_logprobs}.")
         if (self.truncate_prompt_tokens is not None
                 and (self.truncate_prompt_tokens == 0
                      or self.truncate_prompt_tokens < -1)):
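A quick behavioral sketch of the relaxed SamplingParams validation (illustrative, not from the diff): -1 is now accepted for prompt_logprobs, while other negative values still raise at construction time.

# Illustrative only: SamplingParams validates its arguments on construction,
# so -1 passes and any other negative value raises ValueError.
from vllm import SamplingParams

SamplingParams(prompt_logprobs=-1)       # accepted after this commit
try:
    SamplingParams(prompt_logprobs=-2)   # still rejected
except ValueError as err:
    print(err)  # prompt_logprobs must be non-negative or -1, got -2.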

vllm/v1/engine/processor.py

Lines changed: 18 additions & 10 deletions
@@ -65,19 +65,27 @@ def _validate_logprobs(
     ) -> None:
         max_logprobs = self.model_config.max_logprobs
         if max_logprobs == -1:
-            return
+            max_logprobs = self.model_config.get_vocab_size()
+
         # Validate sample logprobs.
-        if params.logprobs and (params.logprobs == -1
-                                or params.logprobs > max_logprobs):
-            raise ValueError(
-                f"Requested sample logprobs of {params.logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.logprobs:
+            num_logprobs = params.logprobs
+            if num_logprobs == -1:
+                num_logprobs = self.model_config.get_vocab_size()
+            if num_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested sample logprobs of {num_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
         # Validate prompt logprobs.
-        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
-            raise ValueError(
-                f"Requested prompt logprobs of {params.prompt_logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.prompt_logprobs:
+            num_prompt_logprobs = params.prompt_logprobs
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.model_config.get_vocab_size()
+            if num_prompt_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
     def _validate_sampling_params(
         self,
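The processor change, restated as a standalone sketch (the helper name and numbers below are hypothetical, not vLLM API): a request of -1 resolves to the full vocabulary size before being checked against max_logprobs, which itself expands to vocab_size when configured as -1.

# Standalone sketch of the resolution logic in _validate_logprobs; the helper
# name and the example numbers are hypothetical, not part of vLLM's API.
def resolve_num_logprobs(requested: int, vocab_size: int, max_logprobs: int) -> int:
    if max_logprobs == -1:
        # An uncapped engine allows up to the full vocabulary.
        max_logprobs = vocab_size
    # A request of -1 means "all vocab logprobs".
    num = vocab_size if requested == -1 else requested
    if num > max_logprobs:
        raise ValueError(f"Requested sample logprobs of {num}, "
                         f"which is greater than max allowed: {max_logprobs}")
    return num

print(resolve_num_logprobs(-1, vocab_size=32000, max_logprobs=-1))  # 32000
print(resolve_num_logprobs(5, vocab_size=32000, max_logprobs=20))   # 5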

vllm/v1/worker/gpu_input_batch.py

Lines changed: 3 additions & 2 deletions
@@ -360,8 +360,9 @@ def add_request(
                                          if sampling_params.logprobs == -1
                                          else sampling_params.logprobs)
         if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[
-                req_id] = sampling_params.prompt_logprobs
+            self.num_prompt_logprobs[req_id] = (
+                self.vocab_size if sampling_params.prompt_logprobs == -1
+                else sampling_params.prompt_logprobs)
 
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
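Downstream, the worker's input batch keeps a per-request count of prompt logprobs, and -1 is normalized to the vocabulary size at that point. A toy illustration with simplified names (not the actual InputBatch class):

# Toy illustration of the normalization applied when a request is added to the
# batch; attribute and variable names are simplified, not vLLM internals.
vocab_size = 32000
num_prompt_logprobs: dict[str, int] = {}

def add_request(req_id: str, prompt_logprobs: int | None) -> None:
    if prompt_logprobs is not None:
        num_prompt_logprobs[req_id] = (
            vocab_size if prompt_logprobs == -1 else prompt_logprobs)

add_request("req-0", -1)
add_request("req-1", 5)
print(num_prompt_logprobs)  # {'req-0': 32000, 'req-1': 5}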
