
Commit 6c834dd

WoosukKwon authored and zhewenl committed
[BugFix][Spec Decode] Use float64 for uniform_probs (vllm-project#23803)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 544ac6e commit 6c834dd

File tree

2 files changed: +7 −2 lines


examples/offline_inference/spec_decode.py

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ def main():
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
     if not args.custom_mm_prompts:
         outputs = llm.generate(
-            TokensPrompt(prompt_token_ids=prompt_ids),
+            [TokensPrompt(prompt_token_ids=x) for x in prompt_ids],
             sampling_params=sampling_params,
         )
     else:
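The change above fixes the example script: `prompt_ids` is a list of token-id lists, one per request, so each entry must be wrapped in its own `TokensPrompt` rather than handing the whole list to a single `TokensPrompt`. A minimal sketch of the corrected pattern, with a placeholder model name and toy token ids (`LLM`, `SamplingParams`, and `TokensPrompt` are vLLM's public API):

# Batched generation from pre-tokenized prompts. The model name and
# token ids are illustrative placeholders, not values from this commit.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# One token-id list per request.
prompt_ids = [[1, 2, 3], [4, 5, 6, 7]]

# Wrap each request's ids in its own TokensPrompt, as in the diff above.
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=x) for x in prompt_ids],
    sampling_params=sampling_params,
)
for output in outputs:
    print(output.outputs[0].text)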

vllm/v1/sample/rejection_sampler.py

Lines changed: 6 additions & 1 deletion

@@ -365,9 +365,14 @@ def generate_uniform_probs(
         A tensor of shape `(num_tokens, )` containing uniform
         random values in the range [0, 1).
     """
+    # NOTE(woosuk): We deliberately use float64 instead of float32 here
+    # because when using float32, there's a non-negligible chance that
+    # uniform_prob is sampled to be exact 0.0 as reported in
+    # https://github.com/pytorch/pytorch/issues/16706. Using float64
+    # mitigates the issue.
     uniform_probs = torch.rand(
         (num_tokens, ),
-        dtype=torch.float32,
+        dtype=torch.float64,
         device=device,
     )
     start_idx = 0
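The NOTE added in the diff points at pytorch/pytorch#16706: in float32, `torch.rand` returns exactly 0.0 with probability roughly 2^-24 per sample, which is frequent enough to matter at serving scale and is a value the rejection sampler's acceptance test does not expect. Below is a small, self-contained demonstration of that failure mode; it is illustrative only, not vLLM code, and the sample count is arbitrary (it allocates a few hundred MB):

# Shows that float32 torch.rand can emit exact zeros, while float64
# makes them astronomically unlikely.
import torch

n = 100_000_000  # ~0.4 GB in float32; expect a handful of exact zeros

u32 = torch.rand(n, dtype=torch.float32)
print("float32 exact zeros:", (u32 == 0.0).sum().item())  # typically > 0

u64 = torch.rand(n, dtype=torch.float64)
print("float64 exact zeros:", (u64 == 0.0).sum().item())  # virtually always 0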
