Generate: improve assisted generation tests (#27540)
This commit is contained in:
@@ -23,6 +23,7 @@ import numpy as np
|
||||
|
||||
from transformers import is_torch_available, pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_torch,
|
||||
require_torch_multi_accelerator,
|
||||
@@ -1506,10 +1507,14 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
|
||||
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
# It breaks the pattern in the tests above, for multiple reasons:
|
||||
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
|
||||
# shape differences -- and it may result in a different output. The input shape difference happens in the
|
||||
# main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
|
||||
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
||||
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
|
||||
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
|
||||
# prepare the assistant encoder outputs in the main generate body);
|
||||
# - assisted_decoding does not support `use_cache = False`
|
||||
@@ -1520,15 +1525,21 @@ class GenerationTesterMixin:
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
|
||||
for model_name in [
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# This for loop is a naive and temporary effort to make the test less flaky.
|
||||
failed = 0
|
||||
for i in range(10):
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1537,60 +1548,59 @@ class GenerationTesterMixin:
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_greedy = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
|
||||
# be correct
|
||||
output_assisted = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
assistant_model=model,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
# Sets assisted generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
|
||||
# the assistant model is correct
|
||||
# c) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||
# the main model is correct
|
||||
generation_kwargs = {
|
||||
"eos_token_id": -1, # see a)
|
||||
"max_new_tokens": 4, # see c)
|
||||
"num_beams": 1,
|
||||
"do_sample": False,
|
||||
"output_scores": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
try:
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs.update({"assistant_model": assistant_model})
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
||||
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
except AssertionError:
|
||||
failed += 1
|
||||
if failed > 1:
|
||||
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
||||
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
|
||||
@unittest.skip("Failing for a lot of models du to attention mask size missmatch. Works well when standalone.")
|
||||
def test_assisted_decoding_sample(self):
|
||||
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
|
||||
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
|
||||
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
|
||||
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
|
||||
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
|
||||
for model_name in [
|
||||
"bigbirdpegasus",
|
||||
"led",
|
||||
"mega",
|
||||
"speech2text",
|
||||
"git",
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1599,18 +1609,27 @@ class GenerationTesterMixin:
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_assisted = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_beams=1,
|
||||
do_sample=True,
|
||||
assistant_model=model, # triggers assisted decoding
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
# Sets assisted generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
|
||||
# the assistant model is correct
|
||||
# c) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||
# the main model is correct
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs = {
|
||||
"eos_token_id": -1, # see a)
|
||||
"max_new_tokens": 4, # see c)
|
||||
"num_beams": 1,
|
||||
"do_sample": True,
|
||||
"assistant_model": assistant_model,
|
||||
"output_scores": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user