From 12b50c6130ccfbb2381a181dc4bbb2bf07a3c62a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 16 Nov 2023 18:54:20 +0000 Subject: [PATCH] Generate: improve assisted generation tests (#27540) --- tests/generation/test_utils.py | 181 ++++++++++++++++++--------------- 1 file changed, 100 insertions(+), 81 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1705142f8f..1e76c88c71 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -23,6 +23,7 @@ import numpy as np from transformers import is_torch_available, pipeline from transformers.testing_utils import ( + is_flaky, require_accelerate, require_torch, require_torch_multi_accelerator, @@ -1506,10 +1507,14 @@ class GenerationTesterMixin: ) self.assertListEqual(low_output.tolist(), high_output.tolist()) - @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. + @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail. def test_assisted_decoding_matches_greedy_search(self): # This test ensures that the assisted generation does not introduce output changes over greedy search. - # It breaks the pattern in the tests above, for multiple reasons: + # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul + # shape differences -- and it may result in a different output. The input shape difference happens in the + # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at + # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info. + # NOTE (2): It breaks the pattern in the tests above, for multiple reasons: # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to # prepare the assistant encoder outputs in the main generate body); # - assisted_decoding does not support `use_cache = False` @@ -1520,77 +1525,21 @@ class GenerationTesterMixin: self.skipTest("Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() - for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] - ): - self.skipTest("May fix in the future: need model-specific fixes") - - # This for loop is a naive and temporary effort to make the test less flaky. - failed = 0 - for i in range(10): - # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) - - # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config, "use_cache"): - self.skipTest("This model doesn't support caching") - - config.use_cache = True - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - output_greedy = model.generate( - input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_beams=1, - do_sample=False, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will - # be correct - output_assisted = model.generate( - input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_beams=1, - do_sample=False, - assistant_model=model, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - try: - self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) - - for output in (output_greedy, output_assisted): - self._check_outputs(output, input_ids, model.config, use_cache=True) - except AssertionError: - failed += 1 - if failed > 1: - self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) - - for output in (output_greedy, output_assisted): - self._check_outputs(output, input_ids, model.config, use_cache=True) - - @unittest.skip("Failing for a lot of models du to attention mask size missmatch. Works well when standalone.") - def test_assisted_decoding_sample(self): - # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the - # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking). - for model_class in self.all_generative_model_classes: - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest("Won't fix: old model with different cache format") - if any( - model_name in model_class.__name__.lower() - for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"] + for model_name in [ + "bigbirdpegasus", + "led", + "mega", + "speech2text", + "git", + "prophetnet", + "seamlessm4t", + "clvp", + ] ): self.skipTest("May fix in the future: need model-specific fixes") # enable cache - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1599,18 +1548,88 @@ class GenerationTesterMixin: config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() - output_assisted = model.generate( - input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_beams=1, - do_sample=True, - assistant_model=model, # triggers assisted decoding - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) + # Sets assisted generation arguments such that: + # a) no EOS is generated, to ensure generation doesn't break early + # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of + # the assistant model is correct + # c) there are at least two forward passes in the main model, to ensure the input preparation of + # the main model is correct + generation_kwargs = { + "eos_token_id": -1, # see a) + "max_new_tokens": 4, # see c) + "num_beams": 1, + "do_sample": False, + "output_scores": True, + "output_hidden_states": True, + "output_attentions": True, + "return_dict_in_generate": True, + } + output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 # see b) + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) + generation_kwargs.update({"assistant_model": assistant_model}) + output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + + # The two outputs must match and their shape must be as expected + self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + for output in (output_greedy, output_assisted): + self._check_outputs(output, input_ids, model.config, use_cache=True) + + def test_assisted_decoding_sample(self): + # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not + # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with + # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535). + for model_class in self.all_generative_model_classes: + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + self.skipTest("Won't fix: old model with different cache format") + if any( + model_name in model_class.__name__.lower() + for model_name in [ + "bigbirdpegasus", + "led", + "mega", + "speech2text", + "git", + "prophetnet", + "seamlessm4t", + "clvp", + ] + ): + self.skipTest("May fix in the future: need model-specific fixes") + + # enable cache + config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) + + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + self.skipTest("This model doesn't support caching") + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + # Sets assisted generation arguments such that: + # a) no EOS is generated, to ensure generation doesn't break early + # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of + # the assistant model is correct + # c) there are at least two forward passes in the main model, to ensure the input preparation of + # the main model is correct + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 # see b) + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) + generation_kwargs = { + "eos_token_id": -1, # see a) + "max_new_tokens": 4, # see c) + "num_beams": 1, + "do_sample": True, + "assistant_model": assistant_model, + "output_scores": True, + "output_hidden_states": True, + "output_attentions": True, + "return_dict_in_generate": True, + } + output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)