Generation: stop at eos for assisted decoding (#31301)
* fix * move changes to prompt lookup * add test * set eos in assistant model * style * fix flakiness * changes for new `main` * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add comment to explain --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
9d6c0641c4
commit
4ab33c2d81
@@ -108,6 +108,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
|||||||
self.assistant_model = assistant_model
|
self.assistant_model = assistant_model
|
||||||
self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
||||||
|
|
||||||
|
# Set eos in assistant same as in target model
|
||||||
|
self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
|
||||||
|
|
||||||
# Prepare the kwargs for the assistant model
|
# Prepare the kwargs for the assistant model
|
||||||
assistant_kwargs = {}
|
assistant_kwargs = {}
|
||||||
for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
|
for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
|
||||||
@@ -267,6 +270,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
eos_token_id: torch.Tensor = None,
|
||||||
num_output_tokens: int = 10,
|
num_output_tokens: int = 10,
|
||||||
max_matching_ngram_size: int = None,
|
max_matching_ngram_size: int = None,
|
||||||
max_length: int = 20,
|
max_length: int = 20,
|
||||||
@@ -274,6 +278,7 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||||||
self.num_output_tokens = num_output_tokens
|
self.num_output_tokens = num_output_tokens
|
||||||
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
|
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
|
if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
|
||||||
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
|
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
|
||||||
@@ -319,6 +324,15 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
|
|||||||
if start_idx < end_idx:
|
if start_idx < end_idx:
|
||||||
chosen_ids = input_ids[0, start_idx:end_idx]
|
chosen_ids = input_ids[0, start_idx:end_idx]
|
||||||
match_found = True
|
match_found = True
|
||||||
|
|
||||||
|
# remove remaining candidate ids if an "eos" token is found, otherwise the target model may
|
||||||
|
# accept eos and the rest as valid, thus not stopping generation after "eos"
|
||||||
|
# NOTE: below code is written based on the fact that assisted decoding supports only bs=1
|
||||||
|
mask = torch.isin(chosen_ids, self.eos_token_id)
|
||||||
|
match_indices_eos = torch.nonzero(mask)
|
||||||
|
if match_indices_eos.numel() > 0:
|
||||||
|
first_eos_index = match_indices_eos[0].item()
|
||||||
|
chosen_ids = chosen_ids[:first_eos_index]
|
||||||
break
|
break
|
||||||
if match_found:
|
if match_found:
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -725,6 +725,7 @@ class GenerationMixin:
|
|||||||
"""
|
"""
|
||||||
if generation_config.prompt_lookup_num_tokens is not None:
|
if generation_config.prompt_lookup_num_tokens is not None:
|
||||||
candidate_generator = PromptLookupCandidateGenerator(
|
candidate_generator = PromptLookupCandidateGenerator(
|
||||||
|
eos_token_id=generation_config._eos_token_tensor,
|
||||||
num_output_tokens=generation_config.prompt_lookup_num_tokens,
|
num_output_tokens=generation_config.prompt_lookup_num_tokens,
|
||||||
max_matching_ngram_size=generation_config.max_matching_ngram_size,
|
max_matching_ngram_size=generation_config.max_matching_ngram_size,
|
||||||
max_length=generation_config.max_length,
|
max_length=generation_config.max_length,
|
||||||
@@ -3954,7 +3955,6 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
||||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
||||||
candidate_input_ids = candidate_input_ids.to(self.device)
|
|
||||||
if candidate_logits is not None:
|
if candidate_logits is not None:
|
||||||
candidate_logits = candidate_logits.to(self.device)
|
candidate_logits = candidate_logits.to(self.device)
|
||||||
|
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ if is_torch_available():
|
|||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
MinLengthLogitsProcessor,
|
MinLengthLogitsProcessor,
|
||||||
PhrasalConstraint,
|
PhrasalConstraint,
|
||||||
|
PromptLookupCandidateGenerator,
|
||||||
SampleDecoderOnlyOutput,
|
SampleDecoderOnlyOutput,
|
||||||
SampleEncoderDecoderOutput,
|
SampleEncoderDecoderOutput,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
@@ -1372,6 +1373,34 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
|
def test_prompt_lookup_decoding_stops_at_eos(self):
|
||||||
|
# This test ensures that the prompt lookup generation stops at eos token and does not suggest more tokens
|
||||||
|
# (see https://github.com/huggingface/transformers/pull/31301)
|
||||||
|
|
||||||
|
# The main idea is to have an ngram (unigram in our case) that is repeated twice in the input ids.
|
||||||
|
# First time at the very end, so input ends with the unigrams, and second any arbitrary location.
|
||||||
|
# Also, we need an EOS token which will be injected just after the arbitrary located ngram.
|
||||||
|
# We verify that PLD will not copy and propose candidated that contain an EOS token, even if there are overlapping ngrams
|
||||||
|
# in input ids. Otherwise a proposed EOS along with the trailing (ngrams-1) tokens might be accepted by the target model.
|
||||||
|
# That seems as if the model "generated" and EOS but didn't stop from user's perspective
|
||||||
|
|
||||||
|
input_ids = torch.randint(1, 50, (1, 10), device=torch_device) # generate inputs in range from 1-50
|
||||||
|
arbitrary_ngram = 51 # this is the arbitrary ngram, specifically chosen OOV to prevent flaky tests
|
||||||
|
input_ids[:, 3] = arbitrary_ngram # set pre-eos to arbitrary_ngram which is for sure not present in inputs
|
||||||
|
input_ids[:, -1] = arbitrary_ngram # put arbitrary_ngram in the end for the necessary match to happen
|
||||||
|
|
||||||
|
eos_token_id = torch.tensor([0], device=torch_device)
|
||||||
|
input_ids[:, 4] = eos_token_id # inject eos-token-id in input ids so that it is located after arbitrary_ngram
|
||||||
|
|
||||||
|
# init cand geenerator with max_matching_ngram_size=1 to match per-token
|
||||||
|
candidate_generator = PromptLookupCandidateGenerator(
|
||||||
|
eos_token_id=eos_token_id, num_output_tokens=4, max_matching_ngram_size=1
|
||||||
|
)
|
||||||
|
output_prompt_lookup = candidate_generator.get_candidates(input_ids)[0]
|
||||||
|
|
||||||
|
# PLD shouldn't propose any new tokens based on eos-match
|
||||||
|
self.assertTrue(output_prompt_lookup.shape[-1] == 10)
|
||||||
|
|
||||||
def test_generate_with_head_masking(self):
|
def test_generate_with_head_masking(self):
|
||||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
|||||||
Reference in New Issue
Block a user