From bbc00046b9b746d5ea69a27a58e652fdf617c91c Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 10 Jan 2025 15:40:04 +0100 Subject: [PATCH] Fix flaky `test_custom_4d_attention_mask` (#35606) * fix * fix --------- Co-authored-by: ydshieh --- src/transformers/testing_utils.py | 10 ++++++++-- tests/test_modeling_common.py | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 7876b22a2b..89587d303e 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1431,14 +1431,20 @@ def set_model_tester_for_less_flaky_test(test_case): and target_num_hidden_layers is not None ): test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config) - test_case.model_tester.vision_config["num_hidden_layers"] = target_num_hidden_layers + if isinstance(test_case.model_tester.vision_config, dict): + test_case.model_tester.vision_config["num_hidden_layers"] = 1 + else: + test_case.model_tester.vision_config.num_hidden_layers = 1 if ( hasattr(test_case.model_tester, "text_config") and "num_hidden_layers" in test_case.model_tester.text_config and target_num_hidden_layers is not None ): test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config) - test_case.model_tester.text_config["num_hidden_layers"] = target_num_hidden_layers + if isinstance(test_case.model_tester.text_config, dict): + test_case.model_tester.text_config["num_hidden_layers"] = 1 + else: + test_case.model_tester.text_config.num_hidden_layers = 1 # A few model class specific handling diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3c3c6684c2..6a9b8523f9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4707,13 +4707,17 @@ class ModelTesterMixin: reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks" ) + set_model_tester_for_less_flaky_test(self) + for model_class in self.all_generative_model_classes: if not model_class._supports_static_cache: self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") config, _ = self.model_tester.prepare_config_and_inputs_for_common() + set_config_for_less_flaky_test(config) if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0: self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") model = model_class(config).to(device=torch_device, dtype=torch.float32) + set_model_for_less_flaky_test(model) ( input_ids,