Skip tests properly (#31308)
* Skip tests properly * [test_all] * Add 'reason' as kwarg for skipTest * [test_all] Fix up * [test_all]
This commit is contained in:
@@ -463,9 +463,9 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@@ -625,9 +625,9 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest("Won't fix: model with non-standard dictionary output shapes")
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
@@ -667,7 +667,7 @@ class GenerationTesterMixin:
|
||||
def test_model_parallel_beam_search(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "xpu" in torch_device:
|
||||
return unittest.skip("device_map='auto' does not work with XPU devices")
|
||||
return unittest.skip(reason="device_map='auto' does not work with XPU devices")
|
||||
|
||||
if model_class._no_split_modules is None:
|
||||
continue
|
||||
@@ -765,7 +765,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
return
|
||||
self.skipTest(reason="bos_token_id is None")
|
||||
|
||||
# hack in case they are equal, otherwise the attn mask will be [0]
|
||||
if config.bos_token_id == config.pad_token_id:
|
||||
@@ -982,17 +982,17 @@ class GenerationTesterMixin:
|
||||
def test_contrastive_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
self.skipTest(reason="Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -1009,17 +1009,17 @@ class GenerationTesterMixin:
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
self.skipTest(reason="Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -1045,18 +1045,18 @@ class GenerationTesterMixin:
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
self.skipTest(reason="Stateful models don't support contrastive search generation")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||
self.skipTest("TODO: fix me")
|
||||
self.skipTest(reason="TODO: fix me")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@@ -1087,9 +1087,9 @@ class GenerationTesterMixin:
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("May fix in the future: need custom cache handling")
|
||||
self.skipTest(reason="May fix in the future: need custom cache handling")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
@@ -1102,7 +1102,7 @@ class GenerationTesterMixin:
|
||||
"jamba",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
config, input_ids, _ = self._get_input_ids_and_config(batch_size=2)
|
||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||
|
||||
@@ -1135,9 +1135,9 @@ class GenerationTesterMixin:
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
@@ -1151,14 +1151,14 @@ class GenerationTesterMixin:
|
||||
"clvp",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@@ -1206,9 +1206,9 @@ class GenerationTesterMixin:
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
@@ -1222,14 +1222,14 @@ class GenerationTesterMixin:
|
||||
"clvp",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@@ -1268,9 +1268,9 @@ class GenerationTesterMixin:
|
||||
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
@@ -1284,14 +1284,14 @@ class GenerationTesterMixin:
|
||||
"clvp",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@@ -1436,7 +1436,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
if "use_cache" not in inputs:
|
||||
@@ -1445,7 +1445,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest("This model doesn't return `past_key_values`")
|
||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||
|
||||
num_hidden_layers = (
|
||||
getattr(config, "decoder_layers", None)
|
||||
@@ -1553,14 +1553,14 @@ class GenerationTesterMixin:
|
||||
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
||||
self.skipTest("Won't fix: old model with unique inputs/caches/other")
|
||||
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||
self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")
|
||||
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
|
||||
# Let's make it always:
|
||||
# 1. use cache (for obvious reasons)
|
||||
@@ -1582,7 +1582,7 @@ class GenerationTesterMixin:
|
||||
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
||||
outputs = model(**inputs)
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest("This model doesn't return `past_key_values`")
|
||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||
|
||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
||||
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
|
||||
@@ -1632,7 +1632,7 @@ class GenerationTesterMixin:
|
||||
# 👉 tests with and without sampling so we can cover the most common use cases.
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest("This model does not support the new cache format")
|
||||
self.skipTest(reason="This model does not support the new cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
@@ -1689,7 +1689,7 @@ class GenerationTesterMixin:
|
||||
def test_generate_with_quant_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_quantized_cache:
|
||||
self.skipTest("This model does not support the quantized cache format")
|
||||
self.skipTest(reason="This model does not support the quantized cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
|
||||
Reference in New Issue
Block a user