Add a decorator for flaky tests (#19498)
* Add a decorator for flaky tests * Quality * Don't break the rest * Address review comments * Fix test name * Fix typo and print to stderr
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -23,12 +24,13 @@ import shutil
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, List, Union
|
from typing import Iterator, List, Optional, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
@@ -1635,3 +1637,36 @@ class RequestCounter:
|
|||||||
self.other_request_count += 1
|
self.other_request_count += 1
|
||||||
|
|
||||||
return self.old_request(method=method, **kwargs)
|
return self.old_request(method=method, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None):
|
||||||
|
"""
|
||||||
|
To decorate flaky tests. They will be retried on failures.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_attempts (`int`, *optional*, defaults to 5):
|
||||||
|
The maximum number of attempts to retry the flaky test.
|
||||||
|
wait_before_retry (`float`, *optional*):
|
||||||
|
If provided, will wait that number of seconds before retrying the test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(test_func_ref):
|
||||||
|
@functools.wraps(test_func_ref)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
retry_count = 1
|
||||||
|
|
||||||
|
while retry_count < max_attempts:
|
||||||
|
try:
|
||||||
|
return test_func_ref(*args, **kwargs)
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
|
||||||
|
if wait_before_retry is not None:
|
||||||
|
time.sleep(wait_before_retry)
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
return test_func_ref(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import unittest
|
|||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
@@ -359,6 +359,10 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@is_flaky()
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
super().test_retain_grad_hidden_states_attentions()
|
||||||
|
|
||||||
|
|
||||||
def prepare_batch(filename="train-batch.pt"):
|
def prepare_batch(filename="train-batch.pt"):
|
||||||
file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset")
|
file = hf_hub_download(repo_id="kashif/tourism-monthly-batch", filename=filename, repo_type="dataset")
|
||||||
|
|||||||
@@ -21,7 +21,9 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from transformers import Wav2Vec2Config, is_flax_available
|
from transformers import Wav2Vec2Config, is_flax_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
is_flaky,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
|
is_pt_flax_cross_test,
|
||||||
is_pyctcdecode_available,
|
is_pyctcdecode_available,
|
||||||
require_flax,
|
require_flax,
|
||||||
require_librosa,
|
require_librosa,
|
||||||
@@ -302,6 +304,11 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
outputs = model(np.ones((1, 1024), dtype="f4"))
|
outputs = model(np.ones((1, 1024), dtype="f4"))
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
@is_flaky()
|
||||||
|
def test_equivalence_pt_to_flax(self):
|
||||||
|
super().test_equivalence_pt_to_flax()
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import Wav2Vec2Config, is_tf_available
|
from transformers import Wav2Vec2Config, is_tf_available
|
||||||
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
|
from transformers.testing_utils import is_flaky, require_librosa, require_pyctcdecode, require_tf, slow
|
||||||
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
from transformers.utils import is_librosa_available, is_pyctcdecode_available
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -309,6 +309,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||||
|
|
||||||
|
@is_flaky()
|
||||||
def test_labels_out_of_vocab(self):
|
def test_labels_out_of_vocab(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user