Merged
65 changes: 54 additions & 11 deletions src/transformers/generation/candidate_generator.py
@@ -1004,26 +1004,39 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

Args:
max_matching_ngram_size (`int`):
The maximum ngram size to be considered for matching in the prompt
num_output_tokens (`int`):
eos_token_id (`torch.Tensor`, *optional*):
Member Author: (docstring args were out of order, and some were missing)

The token id of the end of sequence token.
num_output_tokens (`int`, *optional*, defaults to 10):
The number of tokens to be output as candidate tokens.
max_length (`int`):
The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
Defaults to 20, which is the max length used as default in generation config.
max_matching_ngram_size (`int`, *optional*, defaults to 2):
The maximum ngram size to be considered for matching in the prompt.
max_length (`int`, *optional*, defaults to 20):
The number of total maximum tokens that can be generated. For decoder-only models that includes the
prompt length. Defaults to 20, which is the max length used as default in generation config.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step. In
prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
forbidden tokens (p = -inf) and block them from being valid candidates.
vocab_size (`int`, *optional*):
The size of the vocabulary. Required if `logits_processor` is provided.
"""

def __init__(
self,
eos_token_id: Optional[torch.Tensor] = None,
num_output_tokens: int = 10,
max_matching_ngram_size: Optional[int] = None,
max_matching_ngram_size: int = 2,
max_length: int = 20,
logits_processor: Optional["LogitsProcessorList"] = None,
vocab_size: Optional[int] = None,
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
self.max_matching_ngram_size = max_matching_ngram_size
self.max_length = max_length
self.eos_token_id = eos_token_id
self.logits_processor = logits_processor
self.vocab_size = vocab_size

if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -1039,7 +1052,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
"""
input_length = input_ids.size(1)
bsz, input_length = input_ids.shape

# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
if self.max_length == input_length + 1:
@@ -1061,13 +1074,43 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
match_indices = matches.nonzero(as_tuple=True)[1]

# Iterate through match indices to find a valid continuation
# TODO (joao): this finds the first valid candidates (left to right), but perhaps we should find the
# longest valid candidates?
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + self.num_output_tokens
end_idx = min(end_idx, input_length, self.max_length)

if start_idx < end_idx:
chosen_ids = input_ids[0, start_idx:end_idx]

# Check if each new candidate token is forbidden according to the logits processor. If all
# tokens are allowed, we keep `chosen_ids` as is.
# 1. create random logits.
# 2. apply the logits processor to get output logits for the next token, using the arbitrary
# logits as input.
# 3. compare the output logits with the next candidate token. If they are -inf, then the next
# candidate token is forbidden and we don't want to generate it.
if self.logits_processor is not None:
sequence_with_candidate = input_ids
fake_input_logits = torch.ones(
(bsz, self.vocab_size), device=input_ids.device, dtype=torch.float32
)
for candidate_idx, new_candidate_token in enumerate(chosen_ids):
fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
fake_candidate_logits = fake_output_logits[0, new_candidate_token]
# next candidate token is forbidden -> crop chosen_ids accordingly
if fake_candidate_logits in (-float("Inf"), torch.finfo(fake_candidate_logits.dtype).min):
chosen_ids = chosen_ids[:candidate_idx]
break
else:
sequence_with_candidate = torch.cat(
(input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1
)
# no valid candidate tokens -> look for a different match
if chosen_ids.shape[0] == 0:
continue

match_found = True

# remove remaining candidate ids if an "eos" token is found, otherwise the target model may
@@ -1082,8 +1125,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
if match_found:
break

if chosen_ids is None or len(chosen_ids) == 0:
# In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
# In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
if not match_found or len(chosen_ids) == 0:
return input_ids, None

# Now we need to extend input_ids with chosen_ids
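
For reference, a minimal standalone sketch (not part of the diff) of the forbidden-token check that the new block performs: apply the logits processor to arbitrary logits and treat any token whose score comes back as -inf (or the dtype minimum) as a forbidden candidate. The toy vocabulary size and suppressed token ids below are made up for illustration.

import torch
from transformers import LogitsProcessorList, SuppressTokensLogitsProcessor

vocab_size = 32  # hypothetical toy vocabulary
forbidden = [5, 7]  # hypothetical token ids that must never be suggested as candidates
processor = LogitsProcessorList([SuppressTokensLogitsProcessor(forbidden)])

input_ids = torch.tensor([[1, 2, 3]])  # current sequence, batch size 1
fake_input_logits = torch.ones((1, vocab_size), dtype=torch.float32)  # values are irrelevant
fake_output_logits = processor(input_ids, fake_input_logits)

is_forbidden = (fake_output_logits[0] == -float("inf")) | (
    fake_output_logits[0] == torch.finfo(fake_output_logits.dtype).min
)
print(is_forbidden[5].item(), is_forbidden[7].item(), is_forbidden[3].item())  # True True False

In `get_candidates` above, the same comparison is what crops `chosen_ids` at the first forbidden position instead of suggesting it to the target model.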
4 changes: 3 additions & 1 deletion src/transformers/generation/utils.py
@@ -1036,8 +1036,10 @@ def _get_candidate_generator(
candidate_generator = PromptLookupCandidateGenerator(
eos_token_id=generation_config._eos_token_tensor,
num_output_tokens=generation_config.prompt_lookup_num_tokens,
max_matching_ngram_size=generation_config.max_matching_ngram_size,
max_matching_ngram_size=generation_config.max_matching_ngram_size or 2,
max_length=generation_config.max_length,
logits_processor=logits_processor,
vocab_size=self.config.get_text_config().vocab_size,
)
elif different_tokenizers:
if generation_config.do_sample is True:
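
From the user side, prompt lookup decoding is still switched on the same way; the difference after this change is that the logits processors built by `generate()` (e.g. from `suppress_tokens` or `bad_words_ids`) now also constrain the suggested candidates. A hedged usage sketch — the checkpoint, prompt, and suppressed ids are placeholders:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "The quick brown fox jumps over the lazy dog. The quick brown"
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    prompt_lookup_num_tokens=3,  # enables PromptLookupCandidateGenerator
    max_matching_ngram_size=2,  # optional; matches the new default above
    suppress_tokens=[tokenizer.eos_token_id],  # example processor that now also filters candidates
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))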
26 changes: 3 additions & 23 deletions tests/generation/test_utils.py
@@ -779,27 +779,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
"blip2", # overridden `generate()` for all BLIP models
"instructblip",
"instructblipvideo",
# TODO: The list is growing huge 🙃! Let's try to check if the config has any of audio/image/video token id and skip the test!
# All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
"llava",
"idefics2",
"idefics3",
"mllama",
"paligemma",
"emu3",
"gotocr2",
"qwen2vl",
"qwen2_5_vl",
"ayavision",
"janus",
"gemma3",
"mistral3",
"chameleon",
"internvl",
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
# All models below: shouldn't suggest audio tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
"voxtral",
"qwen2audio",
Member: yaaaaaay 💃🏻

]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -835,11 +814,12 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
"return_dict_in_generate": True,
"use_cache": True,
}
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)

output_greedy = model.generate(**generation_kwargs, **inputs_dict)
output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)

generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)

# The two outputs must match and their shape must be as expected
self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
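
The updated test builds the model's logits-processor kwargs once and passes them to both calls, so greedy search and prompt lookup decoding are constrained identically and their outputs are expected to match. A simplified sketch of the comparison the test relies on (the real `has_similar_generate_outputs` helper does a fuller comparison; this version only checks the generated token ids):

import torch

def generated_ids_match(output_greedy, output_prompt_lookup) -> bool:
    # Prompt lookup decoding with greedy verification should reproduce greedy search
    # token-for-token, so comparing the generated sequences is enough for this sketch.
    return torch.equal(output_greedy.sequences, output_prompt_lookup.sequences)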
6 changes: 0 additions & 6 deletions tests/models/idefics2/test_modeling_idefics2.py
@@ -390,12 +390,6 @@ def test_flash_attn_2_generate_padding_right(self):
def test_flash_attn_2_inference_padding_right(self):
pass

@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass

@pytest.mark.generate
@slow
@unittest.skip(
6 changes: 0 additions & 6 deletions tests/models/idefics3/test_modeling_idefics3.py
@@ -351,12 +351,6 @@ def test_inputs_embeds():
def test_flash_attn_2_inference_padding_right(self):
pass

@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass

@pytest.mark.generate
@slow
@unittest.skip(
5 changes: 0 additions & 5 deletions tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
@@ -30,7 +30,6 @@
from transformers.testing_utils import (
Expectations,
cleanup,
is_flaky,
require_cv2,
require_flash_attn,
require_torch,
@@ -446,10 +445,6 @@ def test_multi_gpu_data_parallel_forward(self):
def test_model_is_small(self):
pass

@is_flaky() # TODO (joao/raushan): Investigate why this test is flaky on this model
def test_prompt_lookup_decoding_matches_greedy_search(self):
super().test_prompt_lookup_decoding_matches_greedy_search()


@require_torch
class Qwen2_5_VLIntegrationTest(unittest.TestCase):
6 changes: 0 additions & 6 deletions tests/models/smolvlm/test_modeling_smolvlm.py
@@ -345,12 +345,6 @@ def setUp(self):
def test_flash_attn_2_inference_padding_right(self):
pass

@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass

@pytest.mark.generate
@is_flaky(description="TODO: check why flaky")
def test_generate_methods_with_logits_to_keep(self):