diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2366f4dfc6..ac1772d853 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -49,6 +49,7 @@ from transformers.testing_utils import ( USER, CaptureLogger, TestCasePlus, + is_flaky, is_pt_flax_cross_test, is_pt_tf_cross_test, is_staging_test, @@ -340,6 +341,7 @@ class ModelTesterMixin: if hasattr(module, "bias") and module.bias is not None: module.bias.data.fill_(3) + @is_flaky() def test_save_load_fast_init_from_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() base_class = MODEL_MAPPING[config.__class__]