
Conversation

@yannicks1 (Contributor) commented Sep 8, 2025

Re-enable prefill of max model length

Purpose

Prior to #20291 it was possible to prefill a prompt of the maximum model length and sample a single token. However, #20291 introduced an assertion:

start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len  # <-- the new assertion

which no longer allows this.

This PR restores the original behavior. Note that this behavior (prefilling a prompt of the maximum model length and requesting a single token) is consistent with Hugging Face Transformers behavior.

Test Plan

Can be tested with any decoder model by sending a prompt of length max_model_len and setting max_tokens=1, as in the sketch below.
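
A minimal repro sketch (assuming the vLLM offline LLM API; the model name and the way the prompt is padded to max_model_len are placeholders, not part of this PR):

from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

MAX_MODEL_LEN = 2048

llm = LLM(model="facebook/opt-125m", max_model_len=MAX_MODEL_LEN)

# Build a token prompt that exactly fills the model's context window.
tokenizer = llm.get_tokenizer()
prompt_ids = tokenizer.encode("hello " * MAX_MODEL_LEN)[:MAX_MODEL_LEN]

# Request a single token on top of the full-length prefill. Before this PR
# this tripped the assertion; with this PR it returns one sampled token.
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=prompt_ids)],
    SamplingParams(max_tokens=1),
)
print(outputs[0].outputs[0].token_ids)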

Test Result

E.g., before this PR, max_model_len=2048 and max_tokens=1 resulted in this assertion error:

AssertionError: Sampled token IDs exceed the max model length. Total number of tokens: 2049 > max_model_len: 2048

After this PR, the assertion error no longer appears and the correct output token is returned.

Signed-off-by: Yannick Schnider <[email protected]>
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request relaxes an assertion to re-enable prefilling up to the maximum model length and sampling a single token. While the intent is correct, the change as-is will likely cause an IndexError because the underlying buffer for token IDs is not large enough to accommodate the extra token. A fix is required in vllm/v1/worker/gpu_input_batch.py (and likely vllm/v1/worker/tpu_input_batch.py) to increase the buffer size.

Comment on lines +1822 to +1825
assert end_idx <= self.max_model_len + 1, (
"Sampled token IDs exceed the max model length + 1. "
f"Total number of tokens: {end_idx} > max_model_len + 1: "
f"{self.max_model_len + 1}")

Severity: critical

While this change correctly relaxes the assertion to allow for prefilling up to max_model_len and sampling one more token, it seems to introduce a potential IndexError in the subsequent line.

The buffer self.input_batch.token_ids_cpu is initialized with a size of (max_num_reqs, max_model_len) in vllm/v1/worker/gpu_input_batch.py.

When end_idx is self.max_model_len + 1, the slice assignment self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids will attempt to write to an out-of-bounds index (max_model_len).

To fix this, the token_ids_cpu buffer should probably be initialized with a size of (max_num_reqs, max_model_len + 1). This change would be required in vllm/v1/worker/gpu_input_batch.py. A similar change might be needed for vllm/v1/worker/tpu_input_batch.py as well.
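
For illustration, a toy sketch of the sizing issue (the shapes here are made up; the real buffer is self.input_batch.token_ids_cpu, allocated in vllm/v1/worker/gpu_input_batch.py):

import numpy as np

max_num_reqs, max_model_len = 4, 8

# Allocating one extra column per request leaves room for the single token
# sampled after a full-length prefill; with (max_num_reqs, max_model_len)
# the write below would fall one column outside the buffer.
token_ids_cpu = np.zeros((max_num_reqs, max_model_len + 1), dtype=np.int64)

req_idx = 0
start_idx = max_model_len               # prefill already filled the whole row
sampled_ids = [42]                      # one token sampled on top of it
end_idx = start_idx + len(sampled_ids)  # == max_model_len + 1

assert end_idx <= max_model_len + 1     # the relaxed assertion from this PR
token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids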

@yannicks1 (Contributor, Author) replied:

Not sure if the change in vllm/v1/worker/tpu_input_batch.py is needed.

Signed-off-by: Yannick Schnider <[email protected]>
@mergify bot added the "tpu (Related to Google TPUs)" label on Sep 8, 2025
@yannicks1 (Contributor, Author) commented:

@WoosukKwon @LucasWilkinson tagging you guys here as author/reviewer of #20291
