Skip tests properly (#31308)
* Skip tests properly * [test_all] * Add 'reason' as kwarg for skipTest * [test_all] Fix up * [test_all]
This commit is contained in:
@@ -298,7 +298,7 @@ class ModelTesterMixin:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class._keep_in_fp32_modules is None:
|
||||
return
|
||||
self.skipTest(reason="Model class has no _keep_in_fp32_modules attribute defined")
|
||||
|
||||
model = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -392,7 +392,8 @@ class ModelTesterMixin:
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
return
|
||||
self.skipTest(reason="Model class not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
@@ -522,94 +523,11 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertEqual(tied_params1, tied_params2)
|
||||
|
||||
def test_fast_init_context_manager(self):
|
||||
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
|
||||
class MyClass(PreTrainedModel):
|
||||
config_class = PretrainedConfig
|
||||
|
||||
def __init__(self, config=None):
|
||||
super().__init__(config if config is not None else PretrainedConfig())
|
||||
self.linear = nn.Linear(10, 10, bias=True)
|
||||
self.embedding = nn.Embedding(10, 10)
|
||||
self.std = 1
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
|
||||
if module.bias is not None:
|
||||
module.bias.data.normal_(mean=0.0, std=self.std)
|
||||
|
||||
# 2. Make sure a linear layer's reset params is properly skipped:
|
||||
with ContextManagers([no_init_weights(True)]):
|
||||
no_init_instance = MyClass()
|
||||
|
||||
set_seed(0)
|
||||
expected_bias = torch.tensor(
|
||||
([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475])
|
||||
)
|
||||
init_instance = MyClass()
|
||||
torch.testing.assert_close(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)
|
||||
|
||||
set_seed(0)
|
||||
torch.testing.assert_close(
|
||||
init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
|
||||
)
|
||||
|
||||
# 3. Make sure weights that are not present use init_weight_ and get expected values
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
state_dict = init_instance.state_dict()
|
||||
del state_dict["linear.weight"]
|
||||
|
||||
init_instance.config.save_pretrained(tmpdirname)
|
||||
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
||||
set_seed(0)
|
||||
model_fast_init = MyClass.from_pretrained(tmpdirname)
|
||||
|
||||
set_seed(0)
|
||||
model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
|
||||
self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_fast_init_tied_embeddings(self):
|
||||
class MyClass(PreTrainedModel):
|
||||
config_class = PretrainedConfig
|
||||
_tied_weights_keys = ["output_embeddings.weight"]
|
||||
|
||||
def __init__(self, config=None):
|
||||
super().__init__(config if config is not None else PretrainedConfig())
|
||||
self.input_embeddings = nn.Embedding(10, 10)
|
||||
self.output_embeddings = nn.Linear(10, 10, bias=False)
|
||||
self.tie_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.output_embeddings
|
||||
|
||||
def set_output_embeddings(self, output_embeddings):
|
||||
self.output_embeddings = output_embeddings
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.input_embeddings
|
||||
|
||||
def set_input_embeddings(self, input_embeddings):
|
||||
self.input_embeddings = input_embeddings
|
||||
|
||||
def _init_weights(self, module):
|
||||
if module is self.output_embeddings:
|
||||
raise ValueError("unnecessarily initialized tied output embedding!")
|
||||
|
||||
model = MyClass()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
# throws if it initializes the tied output_embeddings
|
||||
MyClass.from_pretrained(tmpdirname)
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
return
|
||||
self.skipTest(reason="Model class not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
@@ -664,7 +582,8 @@ class ModelTesterMixin:
|
||||
def test_torch_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
return
|
||||
self.skipTest(reason="Model class not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
@@ -748,38 +667,6 @@ class ModelTesterMixin:
|
||||
else:
|
||||
check_determinism(first, second)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
expected_arg_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and self.has_attentions:
|
||||
expected_arg_names = ["pixel_values", "output_hidden_states", "output_attentions", "return_dict"]
|
||||
self.assertListEqual(arg_names, expected_arg_names)
|
||||
elif model_class.__name__ in [*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] and not self.has_attentions:
|
||||
expected_arg_names = ["pixel_values", "output_hidden_states", "return_dict"]
|
||||
self.assertListEqual(arg_names, expected_arg_names)
|
||||
else:
|
||||
expected_arg_names = [model.main_input_name]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_batching_equivalence(self):
|
||||
"""
|
||||
Tests that the model supports batching and that the output is the nearly the same for the same input in
|
||||
@@ -875,7 +762,7 @@ class ModelTesterMixin:
|
||||
|
||||
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
self.skipTest(reason="ModelTester is not configured to run training tests")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if (
|
||||
@@ -914,7 +801,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
self.skipTest(reason="ModelTester is not configured to run training tests")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -1095,7 +982,7 @@ class ModelTesterMixin:
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
self.skipTest(reason="test_torchscript is set to `False`")
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.torchscript = True
|
||||
@@ -1157,7 +1044,7 @@ class ModelTesterMixin:
|
||||
if "attention_mask" in inputs:
|
||||
trace_input["attention_mask"] = inputs["attention_mask"]
|
||||
else:
|
||||
self.skipTest("testing SDPA without attention_mask is not supported")
|
||||
self.skipTest(reason="testing SDPA without attention_mask is not supported")
|
||||
|
||||
model(main_input, attention_mask=inputs["attention_mask"])
|
||||
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
|
||||
@@ -1369,7 +1256,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
return
|
||||
self.skipTest(reason="Model does not support head masking")
|
||||
|
||||
global_rng.seed(42)
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -1439,7 +1326,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_head_pruning(self):
|
||||
if not self.test_pruning:
|
||||
return
|
||||
self.skipTest(reason="Pruning is not activated")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(
|
||||
@@ -1472,7 +1359,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_head_pruning_save_load_from_pretrained(self):
|
||||
if not self.test_pruning:
|
||||
return
|
||||
self.skipTest(reason="Pruning is not activated")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(
|
||||
@@ -1509,7 +1396,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_head_pruning_save_load_from_config_init(self):
|
||||
if not self.test_pruning:
|
||||
return
|
||||
self.skipTest(reason="Pruning is not activated")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(
|
||||
@@ -1544,7 +1431,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_head_pruning_integration(self):
|
||||
if not self.test_pruning:
|
||||
return
|
||||
self.skipTest(reason="Pruning is not activated")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(
|
||||
@@ -1733,7 +1620,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_resize_position_vector_embeddings(self):
|
||||
if not self.test_resize_position_embeddings:
|
||||
return
|
||||
self.skipTest(reason="Model does not have position embeddings")
|
||||
|
||||
(
|
||||
original_config,
|
||||
@@ -1816,7 +1703,7 @@ class ModelTesterMixin:
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
@@ -1916,13 +1803,13 @@ class ModelTesterMixin:
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
original_config.tie_word_embeddings = False
|
||||
|
||||
# if model cannot untied embeddings -> leave test
|
||||
if original_config.tie_word_embeddings:
|
||||
return
|
||||
self.skipTest(reason="Model cannot untied embeddings")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
@@ -1994,7 +1881,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_correct_missing_keys(self):
|
||||
if not self.test_missing_keys:
|
||||
return
|
||||
self.skipTest(reason="test_missing_keys is set to `False`")
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -2022,7 +1909,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_tie_model_weights(self):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
self.skipTest(reason="test_torchscript is set to `False`")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -2481,8 +2368,7 @@ class ModelTesterMixin:
|
||||
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||
if not hasattr(transformers, tf_model_class_name):
|
||||
# transformers does not have this model in TF version yet
|
||||
return
|
||||
self.skipTest(reason="transformers does not have TF version of this model yet")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
@@ -2664,8 +2550,7 @@ class ModelTesterMixin:
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
# no flax model exists for this class
|
||||
return
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
@@ -2736,8 +2621,7 @@ class ModelTesterMixin:
|
||||
fx_model_class_name = "Flax" + model_class.__name__
|
||||
|
||||
if not hasattr(transformers, fx_model_class_name):
|
||||
# no flax model exists for this class
|
||||
return
|
||||
self.skipTest(reason="No Flax model exists for this class")
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
@@ -2849,7 +2733,7 @@ class ModelTesterMixin:
|
||||
|
||||
model_forward_args = inspect.signature(model.forward).parameters
|
||||
if "inputs_embeds" not in model_forward_args:
|
||||
self.skipTest("This model doesn't use `inputs_embeds`")
|
||||
self.skipTest(reason="This model doesn't use `inputs_embeds`")
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
|
||||
@@ -2910,7 +2794,7 @@ class ModelTesterMixin:
|
||||
@require_torch_multi_gpu
|
||||
def test_model_parallelization(self):
|
||||
if not self.test_model_parallel:
|
||||
return
|
||||
self.skipTest(reason="test_model_parallel is set to False")
|
||||
|
||||
# a candidate for testing_utils
|
||||
def get_current_gpu_memory_use():
|
||||
@@ -2972,7 +2856,7 @@ class ModelTesterMixin:
|
||||
@require_torch_multi_gpu
|
||||
def test_model_parallel_equal_results(self):
|
||||
if not self.test_model_parallel:
|
||||
return
|
||||
self.skipTest(reason="test_model_parallel is set to False")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -3221,7 +3105,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_load_with_mismatched_shapes(self):
|
||||
if not self.test_mismatched_shapes:
|
||||
return
|
||||
self.skipTest(reason="test_missmatched_shapes is set to False")
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -3265,7 +3149,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_mismatched_shapes_have_properly_initialized_weights(self):
|
||||
if not self.test_mismatched_shapes:
|
||||
return
|
||||
self.skipTest(reason="test_missmatched_shapes is set to False")
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
@@ -3383,6 +3267,9 @@ class ModelTesterMixin:
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_conversion(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -3409,6 +3296,9 @@ class ModelTesterMixin:
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
@@ -3503,6 +3393,9 @@ class ModelTesterMixin:
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
@@ -3593,6 +3486,9 @@ class ModelTesterMixin:
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_generate_left_padding(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
@@ -3638,6 +3534,9 @@ class ModelTesterMixin:
|
||||
@is_flaky()
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
@@ -3681,6 +3580,9 @@ class ModelTesterMixin:
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
if not self.all_model_classes[0]._supports_sdpa:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
@@ -3979,11 +3881,14 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
|
||||
if not torch.version.cuda or major < 8:
|
||||
self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
@@ -3992,13 +3897,15 @@ class ModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
|
||||
self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input")
|
||||
self.skipTest(
|
||||
reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input"
|
||||
)
|
||||
if config.model_type in ["paligemma"]:
|
||||
self.skipTest(
|
||||
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
|
||||
)
|
||||
if config.model_type in ["idefics"]:
|
||||
self.skipTest("Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
|
||||
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -4020,11 +3927,14 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
major, _ = compute_capability
|
||||
|
||||
if not torch.version.cuda or major < 8:
|
||||
self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
@@ -4060,6 +3970,9 @@ class ModelTesterMixin:
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
@@ -4130,6 +4043,9 @@ class ModelTesterMixin:
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_matches_eager_sliding_window(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"]
|
||||
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
@@ -4184,6 +4100,9 @@ class ModelTesterMixin:
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_generate_use_cache(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -4229,6 +4148,9 @@ class ModelTesterMixin:
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
@@ -4284,8 +4206,7 @@ class ModelTesterMixin:
|
||||
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||
if not hasattr(transformers, tf_model_class_name):
|
||||
# transformers does not have this model in TF version yet
|
||||
return
|
||||
self.skipTest(reason="transformers does not have this model in TF version yet")
|
||||
|
||||
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||
|
||||
@@ -4309,8 +4230,7 @@ class ModelTesterMixin:
|
||||
|
||||
flax_model_class_name = "Flax" + model_class.__name__ # Add the "Flax at the beginning
|
||||
if not hasattr(transformers, flax_model_class_name):
|
||||
# transformers does not have this model in Flax version yet
|
||||
return
|
||||
self.skipTest(reason="transformers does not have this model in Flax version yet")
|
||||
|
||||
flax_model_class = getattr(transformers, flax_model_class_name)
|
||||
|
||||
@@ -4331,6 +4251,9 @@ class ModelTesterMixin:
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_from_config(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
@@ -4407,8 +4330,13 @@ class ModelTesterMixin:
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_custom_4d_attention_mask(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
|
||||
self.skipTest(
|
||||
reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
|
||||
)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
@@ -4453,7 +4381,7 @@ class ModelTesterMixin:
|
||||
@require_read_token
|
||||
def test_torch_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest("This test requires torch >= 2.3 to run.")
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
if not hasattr(self, "_torch_compile_test_ckpt"):
|
||||
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
|
||||
|
||||
Reference in New Issue
Block a user