Adding Prompt lookup decoding (#27775)
* MVP * fix ci * more ci * remove redundant kwarg * added and wired up PromptLookupCandidateGenerator * rebased with main, working * removed print * style fixes * fix test * fixed tests * added test for prompt lookup decoding * fixed circleci * fixed test issue * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/candidate_generator.py * Update src/transformers/generation/candidate_generator.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Joao Gante <joao@huggingface.co> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1569,6 +1569,66 @@ class GenerationTesterMixin:
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
|
||||
@is_flaky()
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
|
||||
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
# Sets assisted generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of
|
||||
# prompt lookup is correct
|
||||
# c) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||
# the main model is correct
|
||||
generation_kwargs = {
|
||||
"eos_token_id": -1, # see a)
|
||||
"max_new_tokens": 4, # see c)
|
||||
"num_beams": 1,
|
||||
"do_sample": False,
|
||||
"output_scores": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
|
||||
output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
|
||||
for output in (output_greedy, output_prompt_lookup):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
|
||||
def test_assisted_decoding_sample(self):
|
||||
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
|
||||
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
|
||||
|
||||
Reference in New Issue
Block a user