Generation: stop at eos for assisted decoding (#31301)

* fix

* move changes to prompt lookup

* add test

* set eos in assistant model

* style

* fix flakiness

* changes for new `main`

* Update tests/generation/test_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/generation/test_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add comment to explain

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-07-26 10:16:06 +05:00
committed by GitHub
parent 9d6c0641c4
commit 4ab33c2d81
3 changed files with 44 additions and 1 deletions

View File

@@ -77,6 +77,7 @@ if is_torch_available():
MaxLengthCriteria,
MinLengthLogitsProcessor,
PhrasalConstraint,
PromptLookupCandidateGenerator,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
StoppingCriteria,
@@ -1372,6 +1373,34 @@ class GenerationTesterMixin:
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
def test_prompt_lookup_decoding_stops_at_eos(self):
# This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens
# (see https://github.com/huggingface/transformers/pull/31301)
# The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids.
# First time at the very end, so input ends with the unigrams, and second any arbitrary location.
# Also, we need an EOS token which will be injected just after the arbitrary located ngram.
# We verify that PLD will not copy and propose candidated that contain an EOS token, even if there are overlapping ngrams
# in input ids. Otherwise a proposed EOS along with the trailing (ngrams-1) tokens might be accepted by the target model.
# That seems as if the model "generated" and EOS but didn't stop from user's perspective
input_ids = torch.randint(1, 50, (1, 10), device=torch_device) # generate inputs in range from 1-50
arbitrary_ngram = 51 # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests
input_ids[:, 3] = arbitrary_ngram # set pre-eos to arbitrary_ngram which is for sure not present in inputs
input_ids[:, -1] = arbitrary_ngram # put arbitrary_ngram in the end for the necessary match to happen
eos_token_id = torch.tensor([0], device=torch_device)
input_ids[:, 4] = eos_token_id # inject eos-token-id in input ids so that it is located after arbitrary_ngram
# init cand geenerator with max_matching_ngram_size=1 to match per-token
candidate_generator = PromptLookupCandidateGenerator(
eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
)
output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0]
# PLD shouldn't propose any new tokens based on eos-match
self.assertTrue(output_prompt_lookup.shape[-1] == 10)
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]