From f2090b8ecc14829942dca65bbbde8fe352e1289b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 5 Sep 2025 15:39:56 +0000
Subject: [PATCH 1/7] no longer flaky :)

---
 .../generation/candidate_generator.py         | 15 +++++++++--
 src/transformers/generation/utils.py          |  1 +
 tests/generation/test_utils.py                | 26 +++----------------
 .../models/idefics2/test_modeling_idefics2.py |  6 -----
 .../models/idefics3/test_modeling_idefics3.py |  6 -----
 .../qwen2_5_vl/test_modeling_qwen2_5_vl.py    |  5 ----
 tests/models/smolvlm/test_modeling_smolvlm.py |  6 -----
 7 files changed, 17 insertions(+), 48 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 38a15f3dae9c..bde4635ac7f5 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -1019,11 +1019,15 @@ def __init__(
         num_output_tokens: int = 10,
         max_matching_ngram_size: Optional[int] = None,
         max_length: int = 20,
+        bad_words_ids: Optional[list[list[int]]] = None,
     ):
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
         self.max_length = max_length
         self.eos_token_id = eos_token_id
+        self.bad_words_ids = (
+            [torch.tensor(bad_word_id) for bad_word_id in bad_words_ids] if bad_words_ids is not None else None
+        )
 
         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -1068,6 +1072,13 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
 
                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
+
+                    # If `chosen_ids` contains any of the sequences in `bad_words_ids`, skip this match -- this
+                    # sequence could never be generated in normal circumstances.
+                    if self.bad_words_ids is not None:
+                        if any(bad_word_id.to(chosen_ids.device) in chosen_ids for bad_word_id in self.bad_words_ids):
+                            continue
+
                     match_found = True
 
                     # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
@@ -1082,8 +1093,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             if match_found:
                 break
 
-        if chosen_ids is None or len(chosen_ids) == 0:
-            # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+        # In case we didn't find a match, return the input sequence unchanged and revert to autoregressive decoding
+        if not match_found or len(chosen_ids) == 0:
             return input_ids, None
 
         # Now need extend input_ids with chosen_ids
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2716c79c702f..eb59f8706b0f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1025,6 +1025,7 @@ def _get_candidate_generator(
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
                 max_matching_ngram_size=generation_config.max_matching_ngram_size,
                 max_length=generation_config.max_length,
+                bad_words_ids=generation_config.bad_words_ids,
             )
         elif different_tokenizers:
             if generation_config.do_sample is True:
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 85841c557df1..f814711011e5 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -779,27 +779,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "blip2",  # overridden `generate()` for all BLIP models
                 "instructblip",
                 "instructblipvideo",
-                # TODO: The list is growing huge 🙃! Let's try to check if the config has any of audio/image/video token id and skip the test!
-                # All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                "llava",
-                "idefics2",
-                "idefics3",
-                "mllama",
-                "paligemma",
-                "emu3",
-                "gotocr2",
-                "qwen2vl",
-                "qwen2_5_vl",
-                "ayavision",
-                "janus",
-                "gemma3",
-                "mistral3",
-                "chameleon",
-                "internvl",
-                "qwen2_5omni",  # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
-                # All models below: shouldn't suggest audio tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                "voxtral",
-                "qwen2audio",
             ]
         ):
            self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -835,11 +814,12 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "return_dict_in_generate": True,
                 "use_cache": True,
             }
+            logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
 
-            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
+            output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
             generation_kwargs.update({"prompt_lookup_num_tokens": 2})  # see b)
-            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
+            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
 
             # The two outputs must match and their shape must be as expected
             self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
 
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index 199664a73d85..a500d8bf4946 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -390,12 +390,6 @@ def test_flash_attn_2_generate_padding_right(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @slow
     @unittest.skip(
diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py
index 97cff53643bc..b4434f34b81c 100644
--- a/tests/models/idefics3/test_modeling_idefics3.py
+++ b/tests/models/idefics3/test_modeling_idefics3.py
@@ -351,12 +351,6 @@ def test_inputs_embeds():
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @slow
     @unittest.skip(
diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
index a8917f72928e..a105302a9952 100644
--- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
+++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
@@ -30,7 +30,6 @@
 from transformers.testing_utils import (
     Expectations,
     cleanup,
-    is_flaky,
     require_cv2,
     require_flash_attn,
     require_torch,
@@ -446,10 +445,6 @@ def test_multi_gpu_data_parallel_forward(self):
     def test_model_is_small(self):
         pass
 
-    @is_flaky()  # TODO (joao/raushan): Investigate why this test is flaky on this model
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        super().test_prompt_lookup_decoding_matches_greedy_search()
-
 
 @require_torch
 class Qwen2_5_VLIntegrationTest(unittest.TestCase):
diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py
index e6e38eb10534..6a3c8c5fa346 100644
--- a/tests/models/smolvlm/test_modeling_smolvlm.py
+++ b/tests/models/smolvlm/test_modeling_smolvlm.py
@@ -345,12 +345,6 @@ def setUp(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @is_flaky(description="TODO: check why flaky")
     def test_generate_methods_with_logits_to_keep(self):
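
Patch 1 threads `bad_words_ids` into the prompt lookup candidate generator, so that spans copied from the prompt can no longer contain a banned sequence. A minimal usage sketch of the behavior being fixed (the checkpoint and the banned word are placeholders, not taken from the patch):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The quick brown fox jumps over the lazy dog. The quick brown fox", return_tensors="pt")

    # `bad_words_ids` is enforced by a logits processor during decoding; with this patch,
    # prompt lookup also refuses to propose prompt spans containing these ids, keeping
    # prompt lookup decoding consistent with plain greedy search.
    bad_words_ids = tokenizer([" lazy"], add_special_tokens=False).input_ids
    output = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=20,
        prompt_lookup_num_tokens=3,
        bad_words_ids=bad_words_ids,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))
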
From 46cc6d51f952e59e7ae445c3ea3e29c3a9aeeb0e Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 9 Sep 2025 09:18:05 +0000
Subject: [PATCH 2/7] PR comments

---
 src/transformers/generation/candidate_generator.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index bde4635ac7f5..5bcd34a4fff1 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -1009,8 +1009,11 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
         num_output_tokens (`int`):
             The number of tokens to be output as candidate tokens.
         max_length (`int`):
-            The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
-            Defaults to 20, which is the max length used as default in generation config.
+            The number of total maximum tokens that can be generated. For decoder-only models, that includes the
+            prompt length. Defaults to 20, which is the max length used as default in generation config.
+        bad_words_ids (`list[list[int]]`, *optional*):
+            List of sequences of tokens that are not allowed to be generated. In other words, if a prompt lookup
+            match contains any of these sequences, the match is skipped.
     """
 
     def __init__(
@@ -1021,6 +1024,8 @@ def __init__(
         max_length: int = 20,
         bad_words_ids: Optional[list[list[int]]] = None,
     ):
+        # TODO (joao): at call time (get_candidates), check if there are other logits processors that set
+        # probabilities to `-inf` and block those tokens
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
         self.max_length = max_length
From 1f72346a81432a98530d60b281091722a069328e Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 9 Sep 2025 10:14:07 +0000
Subject: [PATCH 3/7] any token-blocking logits processor works

---
 .../generation/candidate_generator.py | 63 +++++++++++++------
 src/transformers/generation/utils.py  |  5 +-
 2 files changed, 47 insertions(+), 21 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 5bcd34a4fff1..c267ed642d55 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -1004,35 +1004,39 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
     Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
 
     Args:
-        max_matching_ngram_size (`int`):
+        eos_token_id (`torch.Tensor`, *optional*):
+            The token id of the end of sequence token.
+        max_matching_ngram_size (`int`, *optional*, defaults to 2):
             The maximum ngram size to be considered for matching in the prompt
-        num_output_tokens (`int`):
+        num_output_tokens (`int`, *optional*):
             The number of tokens to be output as candidate tokens.
-        max_length (`int`):
+        max_length (`int`, *optional*, defaults to 20):
             The number of total maximum tokens that can be generated. For decoder-only models, that includes the
             prompt length. Defaults to 20, which is the max length used as default in generation config.
-        bad_words_ids (`list[list[int]]`, *optional*):
-            List of sequences of tokens that are not allowed to be generated. In other words, if a prompt lookup
-            match contains any of these sequences, the match is skipped.
+        logits_processor (`LogitsProcessorList`, *optional*):
+            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+            used to modify the prediction scores of the language modeling head applied at each generation step. In
+            prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
+            forbidden tokens (p = -inf) and block them from being valid candidates.
+        vocab_size (`int`, *optional*):
+            The size of the vocabulary. Required if `logits_processor` is provided.
     """
 
     def __init__(
         self,
         eos_token_id: Optional[torch.Tensor] = None,
-        num_output_tokens: int = 10,
+        num_output_tokens: Optional[int] = None,
         max_matching_ngram_size: Optional[int] = None,
-        max_length: int = 20,
-        bad_words_ids: Optional[list[list[int]]] = None,
+        max_length: Optional[int] = None,
+        logits_processor: Optional["LogitsProcessorList"] = None,
+        vocab_size: Optional[int] = None,
     ):
-        # TODO (joao): at call time (get_candidates), check if there are other logits processors that set
-        # probabilities to `-inf` and block those tokens
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
         self.max_length = max_length
         self.eos_token_id = eos_token_id
-        self.bad_words_ids = (
-            [torch.tensor(bad_word_id) for bad_word_id in bad_words_ids] if bad_words_ids is not None else None
-        )
+        self.logits_processor = logits_processor
+        self.vocab_size = vocab_size
 
         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -1048,7 +1052,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         Return:
             `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
         """
-        input_length = input_ids.size(1)
+        bsz, input_length = input_ids.shape
 
         # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
         if self.max_length == input_length + 1:
@@ -1070,6 +1074,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             match_indices = matches.nonzero(as_tuple=True)[1]
 
             # Iterate through match indices to find a valid continuation
+            # TODO (joao): this finds the first valid candidates (left to right), but perhaps we should find the
+            # longest valid candidates?
             for idx in match_indices:
                 start_idx = idx + ngram_size
                 end_idx = start_idx + self.num_output_tokens
@@ -1078,10 +1084,29 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
 
-                    # If `chosen_ids` contains any of the sequences in `bad_words_ids`, skip this match -- this
-                    # sequence could never be generated in normal circumstances.
-                    if self.bad_words_ids is not None:
-                        if any(bad_word_id.to(chosen_ids.device) in chosen_ids for bad_word_id in self.bad_words_ids):
-                            continue
+                    # Check if each new candidate token is forbidden according to the logits processor. If all
+                    # tokens are allowed, we keep `chosen_ids` as is.
+                    # 1. create random logits.
+                    # 2. apply the logits processor to get output logits for the next token, using the arbitrary
+                    #    logits as input.
+                    # 3. compare the output logits with the next candidate token. If they are -inf, then the next
+                    #    candidate token is forbidden and we don't want to generate it.
+                    if self.logits_processor is not None:
+                        sequence_with_candidate = input_ids
+                        for candidate_idx, new_candidate_token in enumerate(chosen_ids):
+                            fake_input_logits = torch.ones((bsz, self.vocab_size), device=input_ids.device)
+                            fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
+                            fake_candidate_logits = fake_output_logits[0, new_candidate_token]
+                            # next candidate token is forbidden -> crop chosen_ids accordingly
+                            if fake_candidate_logits in (-float("Inf"), torch.finfo(fake_candidate_logits.dtype).min):
+                                chosen_ids = chosen_ids[:candidate_idx]
+                                break
+                            else:
+                                sequence_with_candidate = torch.cat(
+                                    (input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1
+                                )
+                        # no valid candidate tokens -> look for a different match
+                        if chosen_ids.shape[0] == 0:
+                            continue
 
                     match_found = True
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index eb59f8706b0f..ab8b6c0ea093 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1025,7 +1025,8 @@ def _get_candidate_generator(
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
                 max_matching_ngram_size=generation_config.max_matching_ngram_size,
                 max_length=generation_config.max_length,
-                bad_words_ids=generation_config.bad_words_ids,
+                logits_processor=logits_processor,
+                vocab_size=self.config.get_text_config().vocab_size,
             )
         elif different_tokenizers:
             if generation_config.do_sample is True:
@@ -2548,7 +2549,7 @@ def generate(
                 input_ids=input_ids,
                 inputs_tensor=inputs_tensor,
                 assistant_model=generation_mode_kwargs.pop("assistant_model", None),
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 target_tokenizer=generation_mode_kwargs.pop("tokenizer", None),
                 assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None),
                 model_kwargs=model_kwargs,
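
The check above never runs the model: it pushes arbitrary all-ones logits through the logits processors and treats any candidate token that comes back at -inf as forbidden. The same idea in isolation, sketched with a toy vocabulary (illustrative only, not part of the patch):

    import torch
    from transformers import LogitsProcessorList, NoBadWordsLogitsProcessor

    vocab_size = 10
    # Any processor that blocks tokens by setting their scores to -inf works here;
    # NoBadWordsLogitsProcessor is just one concrete example.
    processors = LogitsProcessorList([NoBadWordsLogitsProcessor(bad_words_ids=[[7]], eos_token_id=9)])

    input_ids = torch.tensor([[1, 2, 3]])
    # The input values are irrelevant -- only the -inf mask applied by the processors matters.
    fake_input_logits = torch.ones((1, vocab_size))
    fake_output_logits = processors(input_ids, fake_input_logits)

    print(torch.isinf(fake_output_logits[0, 7]).item())  # True -> token 7 can never be a valid candidate
    print(torch.isinf(fake_output_logits[0, 4]).item())  # False -> token 4 may be proposed
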
From e87e2750a3f810556f6c1d343741861f73a133af Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 9 Sep 2025 10:31:46 +0000
Subject: [PATCH 4/7] ?

---
 src/transformers/generation/candidate_generator.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index f1b1366498f9..bd131ff3a1b4 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -1006,10 +1006,10 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
     Args:
         eos_token_id (`torch.Tensor`, *optional*):
             The token id of the end of sequence token.
+        num_output_tokens (`int`, *optional*, defaults to 10):
+            The number of tokens to be output as candidate tokens.
         max_matching_ngram_size (`int`, *optional*, defaults to 2):
             The maximum ngram size to be considered for matching in the prompt
-        num_output_tokens (`int`, *optional*):
-            The number of tokens to be output as candidate tokens.
         max_length (`int`, *optional*, defaults to 20):
             The number of total maximum tokens that can be generated. For decoder-only models, that includes the
             prompt length. Defaults to 20, which is the max length used as default in generation config.
@@ -1025,9 +1025,9 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
     def __init__(
         self,
         eos_token_id: Optional[torch.Tensor] = None,
-        num_output_tokens: Optional[int] = None,
+        num_output_tokens: int = 10,
         max_matching_ngram_size: Optional[int] = None,
-        max_length: Optional[int] = None,
+        max_length: int = 20,
         logits_processor: Optional["LogitsProcessorList"] = None,
         vocab_size: Optional[int] = None,
     ):

From d57ee3182190dc2667f8a4e7adb723ccb414f587 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 9 Sep 2025 10:32:53 +0000
Subject: [PATCH 5/7] default

---
 src/transformers/generation/candidate_generator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index bd131ff3a1b4..e8dd6961c018 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -1026,13 +1026,13 @@ def __init__(
         self,
         eos_token_id: Optional[torch.Tensor] = None,
         num_output_tokens: int = 10,
-        max_matching_ngram_size: Optional[int] = None,
+        max_matching_ngram_size: int = 2,
         max_length: int = 20,
         logits_processor: Optional["LogitsProcessorList"] = None,
         vocab_size: Optional[int] = None,
     ):
         self.num_output_tokens = num_output_tokens
-        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_matching_ngram_size = max_matching_ngram_size
         self.max_length = max_length
         self.eos_token_id = eos_token_id
         self.logits_processor = logits_processor

From f3453cc6df6f40a6e50e8c5048c5efbc07236195 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 9 Sep 2025 10:43:22 +0000
Subject: [PATCH 6/7] -_-

---
 src/transformers/generation/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e5101640239e..6ccee604f271 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1036,7 +1036,7 @@ def _get_candidate_generator(
             candidate_generator = PromptLookupCandidateGenerator(
                 eos_token_id=generation_config._eos_token_tensor,
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
-                max_matching_ngram_size=generation_config.max_matching_ngram_size,
+                max_matching_ngram_size=generation_config.max_matching_ngram_size or 2,
                 max_length=generation_config.max_length,
                 logits_processor=logits_processor,
                 vocab_size=self.config.get_text_config().vocab_size,

From 45892571c7df886e06361b4fd3ed0a1f7dc056fa Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 9 Sep 2025 10:54:03 +0000
Subject: [PATCH 7/7] create fake tensors once

---
 src/transformers/generation/candidate_generator.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index e8dd6961c018..9f62e4dd0158 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -1093,8 +1093,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
                     #    candidate token is forbidden and we don't want to generate it.
                     if self.logits_processor is not None:
                         sequence_with_candidate = input_ids
+                        fake_input_logits = torch.ones(
+                            (bsz, self.vocab_size), device=input_ids.device, dtype=torch.float32
+                        )
                         for candidate_idx, new_candidate_token in enumerate(chosen_ids):
-                            fake_input_logits = torch.ones((bsz, self.vocab_size), device=input_ids.device)
                             fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
                             fake_candidate_logits = fake_output_logits[0, new_candidate_token]
                             # next candidate token is forbidden -> crop chosen_ids accordingly
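
With the whole series applied, prompt lookup decoding should reproduce greedy search even when token-blocking logits processors are active, which is what the re-enabled test asserts. A condensed sketch of that equivalence check (placeholder checkpoint and banned word, not taken from the test suite):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("A B C D A B C", return_tensors="pt")

    generation_kwargs = {
        "do_sample": False,
        "max_new_tokens": 10,
        "bad_words_ids": tokenizer([" D"], add_special_tokens=False).input_ids,
    }
    output_greedy = model.generate(**inputs, **generation_kwargs)
    output_prompt_lookup = model.generate(**inputs, **generation_kwargs, prompt_lookup_num_tokens=2)

    # Previously, prompt lookup could propose " D" (it appears verbatim in the prompt)
    # even though greedy search could never emit it, and the two outputs could diverge;
    # now the candidate is cropped at the forbidden token and the decoded texts should agree.
    assert tokenizer.decode(output_greedy[0]) == tokenizer.decode(output_prompt_lookup[0])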