diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 782ac6db38fa..9f62e4dd0158 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -1004,26 +1004,39 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
     Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

     Args:
-        max_matching_ngram_size (`int`):
-            The maximum ngram size to be considered for matching in the prompt
-        num_output_tokens (`int`):
+        eos_token_id (`torch.Tensor`, *optional*):
+            The token id of the end of sequence token.
+        num_output_tokens (`int`, *optional*, defaults to 10):
             The number of tokens to be output as candidate tokens.
-        max_length (`int`):
-            The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
-            Defaults to 20, which is the max length used as default in generation config.
+        max_matching_ngram_size (`int`, *optional*, defaults to 2):
+            The maximum ngram size to be considered for matching in the prompt
+        max_length (`int`, *optional*, defaults to 20):
+            The number of total maximum tokens that can be generated. For decoder-only models that includes the
+            prompt length. Defaults to 20, which is the max length used as default in generation config.
+        logits_processor (`LogitsProcessorList`, *optional*):
+            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+            used to modify the prediction scores of the language modeling head applied at each generation step. In
+            prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
+            forbidden tokens (p = -inf) and block them from being valid candidates.
+        vocab_size (`int`, *optional*):
+            The size of the vocabulary. Required if `logits_processor` is provided.
     """

     def __init__(
         self,
         eos_token_id: Optional[torch.Tensor] = None,
         num_output_tokens: int = 10,
-        max_matching_ngram_size: Optional[int] = None,
+        max_matching_ngram_size: int = 2,
         max_length: int = 20,
+        logits_processor: Optional["LogitsProcessorList"] = None,
+        vocab_size: Optional[int] = None,
     ):
         self.num_output_tokens = num_output_tokens
-        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_matching_ngram_size = max_matching_ngram_size
         self.max_length = max_length
         self.eos_token_id = eos_token_id
+        self.logits_processor = logits_processor
+        self.vocab_size = vocab_size

         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -1039,7 +1052,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         Return:
             `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
         """
-        input_length = input_ids.size(1)
+        bsz, input_length = input_ids.shape

         # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
         if self.max_length == input_length + 1:
@@ -1061,6 +1074,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
                 match_indices = matches.nonzero(as_tuple=True)[1]

                 # Iterate through match indices to find a valid continuation
+                # TODO (joao): this finds the first valid candidates (left to right), but perhaps we should find the
+                # longest valid candidates?
                 for idx in match_indices:
                     start_idx = idx + ngram_size
                     end_idx = start_idx + self.num_output_tokens
@@ -1068,6 +1083,34 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,

                     if start_idx < end_idx:
                         chosen_ids = input_ids[0, start_idx:end_idx]
+
+                        # Check if each new candidate token is forbidden according to the logits processor. If all
+                        # tokens are allowed, we keep `chosen_ids` as is.
+                        # 1. create random logits.
+                        # 2. apply the logits processor to get output logits for the next token, using the arbitrary
+                        #    logits as input.
+                        # 3. compare the output logits with the next candidate token. If they are -inf, then the next
+                        #    candidate token is forbidden and we don't want to generate it.
+                        if self.logits_processor is not None:
+                            sequence_with_candidate = input_ids
+                            fake_input_logits = torch.ones(
+                                (bsz, self.vocab_size), device=input_ids.device, dtype=torch.float32
+                            )
+                            for candidate_idx, new_candidate_token in enumerate(chosen_ids):
+                                fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
+                                fake_candidate_logits = fake_output_logits[0, new_candidate_token]
+                                # next candidate token is forbidden -> crop chosen_ids accordingly
+                                if fake_candidate_logits in (-float("Inf"), torch.finfo(fake_candidate_logits.dtype).min):
+                                    chosen_ids = chosen_ids[:candidate_idx]
+                                    break
+                                else:
+                                    sequence_with_candidate = torch.cat(
+                                        (input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1
+                                    )
+                            # no valid candidate tokens -> look for a different match
+                            if chosen_ids.shape[0] == 0:
+                                continue
+
                         match_found = True

                         # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
@@ -1082,8 +1125,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             if match_found:
                 break

-        if chosen_ids is None or len(chosen_ids) == 0:
-            # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+        # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+        if not match_found or len(chosen_ids) == 0:
             return input_ids, None

         # Now need extend input_ids with chosen_ids
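For context, here is a minimal standalone sketch (not part of the patch) of the "fake logits" probe that the new comment describes: pass arbitrary logits through the `LogitsProcessorList` and treat any candidate whose processed score comes back as `-inf` (or `finfo.min`) as forbidden. The processor choice, token ids, and vocabulary size below are made up for illustration.

```python
import torch
from transformers import LogitsProcessorList, SuppressTokensLogitsProcessor

vocab_size = 32  # made-up toy vocabulary
input_ids = torch.tensor([[5, 7, 9, 5, 7]])  # bs=1, as assisted decoding assumes
candidate_token = 9  # hypothetical continuation suggested by the ngram match

logits_processor = LogitsProcessorList([SuppressTokensLogitsProcessor([9])])

# 1. arbitrary "fake" input logits; only the resulting -inf pattern matters, not the values
fake_input_logits = torch.ones((input_ids.shape[0], vocab_size), dtype=torch.float32)
# 2. run the processors as if we were scoring the next token
fake_output_logits = logits_processor(input_ids, fake_input_logits)
# 3. a candidate whose processed logit is -inf (or finfo.min) is forbidden
is_forbidden = fake_output_logits[0, candidate_token] in (
    -float("inf"),
    torch.finfo(fake_output_logits.dtype).min,
)
print(is_forbidden)  # True -> this candidate would be cropped out
```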
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 227a84dd1d06..6ccee604f271 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1036,8 +1036,10 @@ def _get_candidate_generator(
             candidate_generator = PromptLookupCandidateGenerator(
                 eos_token_id=generation_config._eos_token_tensor,
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
-                max_matching_ngram_size=generation_config.max_matching_ngram_size,
+                max_matching_ngram_size=generation_config.max_matching_ngram_size or 2,
                 max_length=generation_config.max_length,
+                logits_processor=logits_processor,
+                vocab_size=self.config.get_text_config().vocab_size,
             )
         elif different_tokenizers:
             if generation_config.do_sample is True:
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index b01733192cd9..1f7354b8ce80 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -779,27 +779,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "blip2",  # overridden `generate()` for all BLIP models
                 "instructblip",
                 "instructblipvideo",
-                # TODO: The list is growing huge 🙃! Let's try to check if the config has any of audio/image/video token id and skip the test!
-                # All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                "llava",
-                "idefics2",
-                "idefics3",
-                "mllama",
-                "paligemma",
-                "emu3",
-                "gotocr2",
-                "qwen2vl",
-                "qwen2_5_vl",
-                "ayavision",
-                "janus",
-                "gemma3",
-                "mistral3",
-                "chameleon",
-                "internvl",
-                "qwen2_5omni",  # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
-                # All models below: shouldn't suggest audio tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                "voxtral",
-                "qwen2audio",
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -835,11 +814,12 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "return_dict_in_generate": True,
                 "use_cache": True,
             }
+            logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)

-            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
+            output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
             generation_kwargs.update({"prompt_lookup_num_tokens": 2})  # see b)
-            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
+            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)

             # The two outputs must match and their shape must be as expected
             self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
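Roughly, the updated test exercises the following behaviour at the `generate()` level: the same logits-processor kwargs are applied to both the greedy call and the prompt-lookup call, and the two outputs are expected to match. The checkpoint and the banned token id in this sketch are placeholders, not what the test suite actually uses.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder tiny checkpoint, chosen only for illustration.
model_id = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# A prompt with a repeated ngram, so prompt lookup can actually propose candidates.
inputs = tokenizer("The quick brown fox jumps over the quick brown", return_tensors="pt")

# `bad_words_ids` adds a logits processor; the candidate generator now sees it too.
common = {"max_new_tokens": 10, "do_sample": False, "bad_words_ids": [[42]]}
out_greedy = model.generate(**inputs, **common)
out_prompt_lookup = model.generate(**inputs, **common, prompt_lookup_num_tokens=2)

# With the processor respected on the candidate side, the two outputs should match.
print(torch.equal(out_greedy, out_prompt_lookup))
```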
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index 199664a73d85..a500d8bf4946 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -390,12 +390,6 @@ def test_flash_attn_2_generate_padding_right(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass

-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @slow
     @unittest.skip(
diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py
index 97cff53643bc..b4434f34b81c 100644
--- a/tests/models/idefics3/test_modeling_idefics3.py
+++ b/tests/models/idefics3/test_modeling_idefics3.py
@@ -351,12 +351,6 @@ def test_inputs_embeds():
     def test_flash_attn_2_inference_padding_right(self):
         pass

-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @slow
     @unittest.skip(
diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
index a8917f72928e..a105302a9952 100644
--- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
+++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py
@@ -30,7 +30,6 @@
 from transformers.testing_utils import (
     Expectations,
     cleanup,
-    is_flaky,
     require_cv2,
     require_flash_attn,
     require_torch,
@@ -446,10 +445,6 @@ def test_multi_gpu_data_parallel_forward(self):
     def test_model_is_small(self):
         pass

-    @is_flaky()  # TODO (joao/raushan): Investigate why this test is flaky on this model
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        super().test_prompt_lookup_decoding_matches_greedy_search()
-

 @require_torch
 class Qwen2_5_VLIntegrationTest(unittest.TestCase):
diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py
index e6e38eb10534..6a3c8c5fa346 100644
--- a/tests/models/smolvlm/test_modeling_smolvlm.py
+++ b/tests/models/smolvlm/test_modeling_smolvlm.py
@@ -345,12 +345,6 @@ def setUp(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass

-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @is_flaky(description="TODO: check why flaky")
     def test_generate_methods_with_logits_to_keep(self):
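As a closing illustration, here is a hypothetical direct use of the updated candidate generator; in practice `generate()` wires this up through `_get_candidate_generator`, and the token ids, vocabulary size, suppressed token, and eos id below are invented.

```python
import torch
from transformers import LogitsProcessorList, SuppressTokensLogitsProcessor
from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

vocab_size = 32
# "3 4" at the end of the prompt matches an earlier ngram, followed by 5, 6, 3.
input_ids = torch.tensor([[3, 4, 5, 6, 3, 4]])

candidate_generator = PromptLookupCandidateGenerator(
    eos_token_id=torch.tensor([2]),  # made-up eos id
    num_output_tokens=3,
    max_matching_ngram_size=2,
    max_length=20,
    logits_processor=LogitsProcessorList([SuppressTokensLogitsProcessor([6])]),
    vocab_size=vocab_size,
)

candidates, candidate_logits = candidate_generator.get_candidates(input_ids)
# Token 6 is suppressed by the processor, so the [5, 6, 3] continuation found in the
# prompt is cropped to just [5] before being appended to the input ids.
print(candidates, candidate_logits)
```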