Skip tests properly (#31308)

* Skip tests properly

* [test_all]

* Add 'reason' as kwarg for skipTest

* [test_all] Fix up

* [test_all]
This commit is contained in:
amyeroberts
2024-06-26 21:59:08 +01:00
committed by GitHub
parent 1f9f57ab4c
commit 1de7dc7403
254 changed files with 1721 additions and 1298 deletions

View File

@@ -463,9 +463,9 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
config.use_cache = True
config.is_decoder = True
@@ -625,9 +625,9 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
@@ -667,7 +667,7 @@ class GenerationTesterMixin:
def test_model_parallel_beam_search(self):
for model_class in self.all_generative_model_classes:
if "xpu" in torch_device:
return unittest.skip("device_map='auto' does not work with XPU devices")
return unittest.skip(reason="device_map='auto' does not work with XPU devices")
if model_class._no_split_modules is None:
continue
@@ -765,7 +765,7 @@ class GenerationTesterMixin:
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
return
self.skipTest(reason="bos_token_id is None")
# hack in case they are equal, otherwise the attn mask will be [0]
if config.bos_token_id == config.pad_token_id:
@@ -982,17 +982,17 @@ class GenerationTesterMixin:
def test_contrastive_generate(self):
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest("Stateful models don't support contrastive search generation")
self.skipTest(reason="Stateful models don't support contrastive search generation")
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
self.skipTest(reason="Won't fix: old model with different cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -1009,17 +1009,17 @@ class GenerationTesterMixin:
def test_contrastive_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest("Stateful models don't support contrastive search generation")
self.skipTest(reason="Stateful models don't support contrastive search generation")
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
self.skipTest(reason="Won't fix: old model with different cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -1045,18 +1045,18 @@ class GenerationTesterMixin:
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest("Stateful models don't support contrastive search generation")
self.skipTest(reason="Stateful models don't support contrastive search generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
self.skipTest("Won't fix: old model with different cache format")
self.skipTest(reason="Won't fix: old model with different cache format")
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
self.skipTest("TODO: fix me")
self.skipTest(reason="TODO: fix me")
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -1087,9 +1087,9 @@ class GenerationTesterMixin:
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest("May fix in the future: need custom cache handling")
self.skipTest(reason="May fix in the future: need custom cache handling")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
@@ -1102,7 +1102,7 @@ class GenerationTesterMixin:
"jamba",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
self.skipTest(reason="May fix in the future: need model-specific fixes")
config, input_ids, _ = self._get_input_ids_and_config(batch_size=2)
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
@@ -1135,9 +1135,9 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest("Stateful models don't support assisted generation")
self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
@@ -1151,14 +1151,14 @@ class GenerationTesterMixin:
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
self.skipTest(reason="May fix in the future: need model-specific fixes")
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -1206,9 +1206,9 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest("Stateful models don't support assisted generation")
self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
@@ -1222,14 +1222,14 @@ class GenerationTesterMixin:
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
self.skipTest(reason="May fix in the future: need model-specific fixes")
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -1268,9 +1268,9 @@ class GenerationTesterMixin:
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest("Stateful models don't support assisted generation")
self.skipTest(reason="Stateful models don't support assisted generation")
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
self.skipTest("Won't fix: old model with different cache format")
self.skipTest(reason="Won't fix: old model with different cache format")
if any(
model_name in model_class.__name__.lower()
for model_name in [
@@ -1284,14 +1284,14 @@ class GenerationTesterMixin:
"clvp",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
self.skipTest(reason="May fix in the future: need model-specific fixes")
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -1436,7 +1436,7 @@ class GenerationTesterMixin:
# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
model = model_class(config).to(torch_device)
if "use_cache" not in inputs:
@@ -1445,7 +1445,7 @@ class GenerationTesterMixin:
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
if "past_key_values" not in outputs:
self.skipTest("This model doesn't return `past_key_values`")
self.skipTest(reason="This model doesn't return `past_key_values`")
num_hidden_layers = (
getattr(config, "decoder_layers", None)
@@ -1553,14 +1553,14 @@ class GenerationTesterMixin:
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
self.skipTest("Won't fix: old model with unique inputs/caches/other")
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
self.skipTest(reason="This model doesn't support caching")
# Let's make it always:
# 1. use cache (for obvious reasons)
@@ -1582,7 +1582,7 @@ class GenerationTesterMixin:
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
self.skipTest("This model doesn't return `past_key_values`")
self.skipTest(reason="This model doesn't return `past_key_values`")
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
@@ -1632,7 +1632,7 @@ class GenerationTesterMixin:
# 👉 tests with and without sampling so we can cover the most common use cases.
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest("This model does not support the new cache format")
self.skipTest(reason="This model does not support the new cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
@@ -1689,7 +1689,7 @@ class GenerationTesterMixin:
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
if not model_class._supports_quantized_cache:
self.skipTest("This model does not support the quantized cache format")
self.skipTest(reason="This model does not support the quantized cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True