From 209bec463637182e5d7a36787d2901a5dcc24136 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 12 Oct 2022 14:00:17 -0400 Subject: [PATCH] Add a decorator for flaky tests (#19498) * Add a decorator for flaky tests * Quality * Don't break the rest * Address review comments * Fix test name * Fix typo and print to stderr --- src/transformers/testing_utils.py | 37 ++++++++++++++++++- .../test_modeling_time_series_transformer.py | 6 ++- .../wav2vec2/test_modeling_flax_wav2vec2.py | 7 ++++ .../wav2vec2/test_modeling_tf_wav2vec2.py | 3 +- 4 files changed, 50 insertions(+), 3 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 5884e642d9..ea4c6c60e5 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -14,6 +14,7 @@ import collections import contextlib +import functools import inspect import logging import os @@ -23,12 +24,13 @@ import shutil import subprocess import sys import tempfile +import time import unittest from collections.abc import Mapping from distutils.util import strtobool from io import StringIO from pathlib import Path -from typing import Iterator, List, Union +from typing import Iterator, List, Optional, Union from unittest import mock import huggingface_hub @@ -1635,3 +1637,36 @@ class RequestCounter: self.other_request_count += 1 return self.old_request(method=method, **kwargs) + + +def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None): + """ + To decorate flaky tests. They will be retried on failures. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*): + If provided, will wait that number of seconds before retrying the test. + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + + except Exception as err: + print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py index d513f1fe21..a3973a39ed 100644 --- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py +++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py @@ -20,7 +20,7 @@ import unittest from huggingface_hub import hf_hub_download from transformers import is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import is_flaky, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -359,6 +359,10 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase): [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length], ) + @is_flaky() + def test_retain_grad_hidden_states_attentions(self): + super().test_retain_grad_hidden_states_attentions() + def prepare_batch(filename="train-batch.pt"): file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset") diff --git a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py index b74e271c02..aa6781a42e 100644 --- a/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_flax_wav2vec2.py @@ -21,7 +21,9 @@ from datasets import load_dataset from transformers import Wav2Vec2Config, is_flax_available from transformers.testing_utils import ( + is_flaky, is_librosa_available, + is_pt_flax_cross_test, is_pyctcdecode_available, require_flax, require_librosa, @@ -302,6 +304,11 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): outputs = model(np.ones((1, 1024), dtype="f4")) self.assertIsNotNone(outputs) + @is_pt_flax_cross_test + @is_flaky() + def test_equivalence_pt_to_flax(self): + super().test_equivalence_pt_to_flax() + @require_flax class FlaxWav2Vec2UtilsTest(unittest.TestCase): diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 665bf2d4f2..6ea1919a33 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -26,7 +26,7 @@ from datasets import load_dataset from huggingface_hub import snapshot_download from transformers import Wav2Vec2Config, is_tf_available -from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow +from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow from transformers.utils import is_librosa_available, is_pyctcdecode_available from ...test_configuration_common import ConfigTester @@ -309,6 +309,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) + @is_flaky() def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs)