From 4ab33c2d81866d4dd2f29df07f1a35491acbb39b Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 26 Jul 2024 10:16:06 +0500 Subject: [PATCH] Generation: stop at `eos` for assisted decoding (#31301) * fix * move changes to prompt lookup * add test * set eos in assistant model * style * fix flakiness * changes for new `main` * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add comment to explain --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../generation/candidate_generator.py | 14 +++++++++ src/transformers/generation/utils.py | 2 +- tests/generation/test_utils.py | 29 +++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 39fa67bfaf..795d373a30 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -108,6 +108,9 @@ class AssistedCandidateGenerator(CandidateGenerator): self.assistant_model = assistant_model self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + # Set eos in assistant same as in target model + self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id + # Prepare the kwargs for the assistant model assistant_kwargs = {} for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads @@ -267,6 +270,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator): def __init__( self, + eos_token_id: torch.Tensor = None, num_output_tokens: int = 10, max_matching_ngram_size: int = None, max_length: int = 20, @@ -274,6 +278,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator): self.num_output_tokens = 
num_output_tokens self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2 self.max_length = max_length + self.eos_token_id = eos_token_id if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") @@ -319,6 +324,15 @@ class PromptLookupCandidateGenerator(CandidateGenerator): if start_idx < end_idx: chosen_ids = input_ids[0, start_idx:end_idx] match_found = True + + # remove remaining candidate ids if an "eos" token is found, otherwise the target model may + # accept eos and the rest as valid, thus not stopping generation after "eos" + # NOTE: below code is written based on the fact that assisted decoding supports only bs=1 + mask = torch.isin(chosen_ids, self.eos_token_id) + match_indices_eos = torch.nonzero(mask) + if match_indices_eos.numel() > 0: + first_eos_index = match_indices_eos[0].item() + chosen_ids = chosen_ids[:first_eos_index] break if match_found: break diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9d3a92d268..b7bfeaf40d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -725,6 +725,7 @@ class GenerationMixin: """ if generation_config.prompt_lookup_num_tokens is not None: candidate_generator = PromptLookupCandidateGenerator( + eos_token_id=generation_config._eos_token_tensor, num_output_tokens=generation_config.prompt_lookup_num_tokens, max_matching_ngram_size=generation_config.max_matching_ngram_size, max_length=generation_config.max_length, @@ -3954,7 +3955,6 @@ class GenerationMixin: # 1. 
Fetch candidate sequences from a `CandidateGenerator` candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) - candidate_input_ids = candidate_input_ids.to(self.device) if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b211836188..2c440bbd71 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -77,6 +77,7 @@ if is_torch_available(): MaxLengthCriteria, MinLengthLogitsProcessor, PhrasalConstraint, + PromptLookupCandidateGenerator, SampleDecoderOnlyOutput, SampleEncoderDecoderOutput, StoppingCriteria, @@ -1372,6 +1373,34 @@ class GenerationTesterMixin: self._check_outputs(output_assisted, input_ids, model.config, use_cache=True) + def test_prompt_lookup_decoding_stops_at_eos(self): + # This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens + # (see https://github.com/huggingface/transformers/pull/31301) + + # The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids. + # First time at the very end, so input ends with the unigrams, and the second time at any arbitrary location. + # Also, we need an EOS token which will be injected just after the arbitrarily located ngram. + # We verify that PLD will not copy and propose candidates that contain an EOS token, even if there are overlapping ngrams + # in input ids. Otherwise a proposed EOS along with the trailing (ngrams-1) tokens might be accepted by the target model.
+ # That seems as if the model "generated" an EOS but didn't stop from the user's perspective + + input_ids = torch.randint(1, 50, (1, 10), device=torch_device) # generate inputs in range from 1-50 + arbitrary_ngram = 51 # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests + input_ids[:, 3] = arbitrary_ngram # set pre-eos to arbitrary_ngram which is for sure not present in inputs + input_ids[:, -1] = arbitrary_ngram # put arbitrary_ngram in the end for the necessary match to happen + + eos_token_id = torch.tensor([0], device=torch_device) + input_ids[:, 4] = eos_token_id # inject eos-token-id in input ids so that it is located after arbitrary_ngram + + # init candidate generator with max_matching_ngram_size=1 to match per-token + candidate_generator = PromptLookupCandidateGenerator( + eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1 + ) + output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0] + + # PLD shouldn't propose any new tokens based on eos-match + self.assertTrue(output_prompt_lookup.shape[-1] == 10) + def test_generate_with_head_masking(self): """Test designed for encoder-decoder models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]