From 90b4adc1f1111f42eada62ea611895646aaee6b6 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 7 Nov 2023 12:08:28 +0000 Subject: [PATCH] Generate: skip tests on unsupported models instead of passing (#27265) --- tests/generation/test_utils.py | 62 +++++++++++++--------------------- 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 7e2f242c6f..7531502be2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -749,8 +749,7 @@ class GenerationTesterMixin: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() if not hasattr(config, "use_cache"): - # only relevant if model has "use_cache" - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -983,8 +982,7 @@ class GenerationTesterMixin: config.forced_eos_token_id = None if not hasattr(config, "use_cache"): - # only relevant if model has "use_cache" - return + self.skipTest("This model doesn't support caching") model = model_class(config).to(torch_device).eval() if model.config.is_encoder_decoder: @@ -1420,13 +1418,13 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return + self.skipTest("Won't fix: old model with different cache format") config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1441,14 +1439,14 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return + self.skipTest("Won't fix: old model with different cache format") # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1472,18 +1470,16 @@ class GenerationTesterMixin: def test_contrastive_generate_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: - # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). - if any( - model_name in model_class.__name__.lower() - for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"] - ): - return + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]): + self.skipTest("Won't fix: old model with different cache format") + if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]): + self.skipTest("TODO: fix me") config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1510,8 +1506,6 @@ class GenerationTesterMixin: ) self.assertListEqual(low_output.tolist(), high_output.tolist()) - return - @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. def test_assisted_decoding_matches_greedy_search(self): # This test ensures that the assisted generation does not introduce output changes over greedy search. @@ -1522,15 +1516,13 @@ class GenerationTesterMixin: # - assisted_decoding does not support `batch_size > 1` for model_class in self.all_generative_model_classes: - # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return - # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes + self.skipTest("Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] ): - return + self.skipTest("May fix in the future: need model-specific fixes") # This for loop is a naive and temporary effort to make the test less flaky. failed = 0 @@ -1540,7 +1532,7 @@ class GenerationTesterMixin: # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1587,24 +1579,21 @@ class GenerationTesterMixin: def test_assisted_decoding_sample(self): # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking). - for model_class in self.all_generative_model_classes: - # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - return - # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes + self.skipTest("Won't fix: old model with different cache format") if any( model_name in model_class.__name__.lower() for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"] ): - return + self.skipTest("May fix in the future: need model-specific fixes") # enable cache config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") config.use_cache = True config.is_decoder = True @@ -1716,7 +1705,7 @@ class GenerationTesterMixin: # If it doesn't support cache, pass the test if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") model = model_class(config).to(torch_device) if "use_cache" not in inputs: @@ -1725,7 +1714,7 @@ class GenerationTesterMixin: # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) if "past_key_values" not in outputs: - return + self.skipTest("This model doesn't return `past_key_values`") num_hidden_layers = ( getattr(config, "decoder_layers", None) @@ -1832,18 +1821,15 @@ class GenerationTesterMixin: def test_generate_continue_from_past_key_values(self): # Tests that we can continue generating from past key values, returned from a previous `generate` call for model_class in self.all_generative_model_classes: - # won't fix: old models with unique inputs/caches/others if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): - return - # may fix in the future: needs modeling or test input preparation fixes for compatibility + self.skipTest("Won't fix: old model with unique inputs/caches/other") if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): - return + self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility") config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - # If it doesn't support cache, pass the test if not hasattr(config, "use_cache"): - return + self.skipTest("This model doesn't support caching") # Let's make it always: # 1. use cache (for obvious reasons) @@ -1862,10 +1848,10 @@ class GenerationTesterMixin: model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 model.generation_config.forced_eos_token_id = None - # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) + # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) outputs = model(**inputs) if "past_key_values" not in outputs: - return + self.skipTest("This model doesn't return `past_key_values`") # Traditional way of generating text, with `return_dict_in_generate` to return the past key values outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)