Generate: improve assisted generation tests (#27540)
This commit is contained in:
@@ -23,6 +23,7 @@ import numpy as np
|
|||||||
|
|
||||||
from transformers import is_torch_available, pipeline
|
from transformers import is_torch_available, pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
is_flaky,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
@@ -1506,10 +1507,14 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
|
|
||||||
@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
|
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||||
def test_assisted_decoding_matches_greedy_search(self):
|
def test_assisted_decoding_matches_greedy_search(self):
|
||||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||||
# It breaks the pattern in the tests above, for multiple reasons:
|
# NOTE (1): The sentence above is true most of the time; there is a tiny difference in the logits due to matmul
|
||||||
|
# shape differences -- and it may result in a different output. The input shape difference happens in the
|
||||||
|
# main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
|
||||||
|
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
||||||
|
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
|
||||||
# - assisted_decoding, contrary to the other methods, can't be called on its own (e.g. needs to
|
# - assisted_decoding, contrary to the other methods, can't be called on its own (e.g. needs to
|
||||||
# prepare the assistant encoder outputs in the main generate body);
|
# prepare the assistant encoder outputs in the main generate body);
|
||||||
# - assisted_decoding does not support `use_cache = False`
|
# - assisted_decoding does not support `use_cache = False`
|
||||||
@@ -1520,77 +1525,21 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest("Won't fix: old model with different cache format")
|
self.skipTest("Won't fix: old model with different cache format")
|
||||||
if any(
|
if any(
|
||||||
model_name in model_class.__name__.lower()
|
model_name in model_class.__name__.lower()
|
||||||
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
|
for model_name in [
|
||||||
):
|
"bigbirdpegasus",
|
||||||
self.skipTest("May fix in the future: need model-specific fixes")
|
"led",
|
||||||
|
"mega",
|
||||||
# This for loop is a naive and temporary effort to make the test less flaky.
|
"speech2text",
|
||||||
failed = 0
|
"git",
|
||||||
for i in range(10):
|
"prophetnet",
|
||||||
# enable cache
|
"seamlessm4t",
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
"clvp",
|
||||||
|
]
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
|
||||||
if not hasattr(config, "use_cache"):
|
|
||||||
self.skipTest("This model doesn't support caching")
|
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
output_greedy = model.generate(
|
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
max_length=max_length,
|
|
||||||
num_beams=1,
|
|
||||||
do_sample=False,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
|
|
||||||
# be correct
|
|
||||||
output_assisted = model.generate(
|
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
max_length=max_length,
|
|
||||||
num_beams=1,
|
|
||||||
do_sample=False,
|
|
||||||
assistant_model=model,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
|
||||||
|
|
||||||
for output in (output_greedy, output_assisted):
|
|
||||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
|
||||||
except AssertionError:
|
|
||||||
failed += 1
|
|
||||||
if failed > 1:
|
|
||||||
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
|
||||||
|
|
||||||
for output in (output_greedy, output_assisted):
|
|
||||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
|
||||||
|
|
||||||
@unittest.skip("Failing for a lot of models du to attention mask size missmatch. Works well when standalone.")
|
|
||||||
def test_assisted_decoding_sample(self):
|
|
||||||
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
|
|
||||||
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
|
||||||
self.skipTest("Won't fix: old model with different cache format")
|
|
||||||
if any(
|
|
||||||
model_name in model_class.__name__.lower()
|
|
||||||
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
|
|
||||||
):
|
):
|
||||||
self.skipTest("May fix in the future: need model-specific fixes")
|
self.skipTest("May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1599,18 +1548,88 @@ class GenerationTesterMixin:
|
|||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_assisted = model.generate(
|
# Sets assisted generation arguments such that:
|
||||||
input_ids,
|
# a) no EOS is generated, to ensure generation doesn't break early
|
||||||
attention_mask=attention_mask,
|
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
|
||||||
max_length=max_length,
|
# the assistant model is correct
|
||||||
num_beams=1,
|
# c) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||||
do_sample=True,
|
# the main model is correct
|
||||||
assistant_model=model, # triggers assisted decoding
|
generation_kwargs = {
|
||||||
output_scores=True,
|
"eos_token_id": -1, # see a)
|
||||||
output_hidden_states=True,
|
"max_new_tokens": 4, # see c)
|
||||||
output_attentions=True,
|
"num_beams": 1,
|
||||||
return_dict_in_generate=True,
|
"do_sample": False,
|
||||||
)
|
"output_scores": True,
|
||||||
|
"output_hidden_states": True,
|
||||||
|
"output_attentions": True,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
}
|
||||||
|
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
|
|
||||||
|
assistant_model = model
|
||||||
|
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||||
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||||
|
generation_kwargs.update({"assistant_model": assistant_model})
|
||||||
|
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
|
|
||||||
|
# The two outputs must match and their shape must be as expected
|
||||||
|
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
||||||
|
for output in (output_greedy, output_assisted):
|
||||||
|
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
|
def test_assisted_decoding_sample(self):
|
||||||
|
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sampling will not
|
||||||
|
# match regular sampling for the same seed, as the forward pass does not return the exact same logits (due to matmul with
|
||||||
|
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
|
self.skipTest("Won't fix: old model with different cache format")
|
||||||
|
if any(
|
||||||
|
model_name in model_class.__name__.lower()
|
||||||
|
for model_name in [
|
||||||
|
"bigbirdpegasus",
|
||||||
|
"led",
|
||||||
|
"mega",
|
||||||
|
"speech2text",
|
||||||
|
"git",
|
||||||
|
"prophetnet",
|
||||||
|
"seamlessm4t",
|
||||||
|
"clvp",
|
||||||
|
]
|
||||||
|
):
|
||||||
|
self.skipTest("May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
|
# enable cache
|
||||||
|
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
|
if not hasattr(config, "use_cache"):
|
||||||
|
self.skipTest("This model doesn't support caching")
|
||||||
|
|
||||||
|
config.use_cache = True
|
||||||
|
config.is_decoder = True
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
# Sets assisted generation arguments such that:
|
||||||
|
# a) no EOS is generated, to ensure generation doesn't break early
|
||||||
|
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
|
||||||
|
# the assistant model is correct
|
||||||
|
# c) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||||
|
# the main model is correct
|
||||||
|
assistant_model = model
|
||||||
|
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||||
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||||
|
generation_kwargs = {
|
||||||
|
"eos_token_id": -1, # see a)
|
||||||
|
"max_new_tokens": 4, # see c)
|
||||||
|
"num_beams": 1,
|
||||||
|
"do_sample": True,
|
||||||
|
"assistant_model": assistant_model,
|
||||||
|
"output_scores": True,
|
||||||
|
"output_hidden_states": True,
|
||||||
|
"output_attentions": True,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
}
|
||||||
|
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
|
|
||||||
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user