From 41925e42135257361b7f02aa20e3bbdab3f7b923 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 25 Feb 2025 14:53:11 -0500 Subject: [PATCH] Add retry hf hub decorator (#35213) * Add retry torch decorator * New approach * Empty commit * Empty commit * Style * Use logger.error * Add a test * Update src/transformers/testing_utils.py Co-authored-by: Lucain * Fix err * Update tests/utils/test_modeling_utils.py --------- Co-authored-by: Lucain Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- src/transformers/testing_utils.py | 47 +++++++++++++++++++++++++++++- tests/test_modeling_common.py | 11 +++++++ tests/utils/test_modeling_utils.py | 13 +++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index bb0b3d3b2f..17223278eb 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -43,6 +43,7 @@ from unittest import mock from unittest.mock import patch import huggingface_hub.utils +import requests import urllib3 from huggingface_hub import delete_repo from packaging import version @@ -200,6 +201,8 @@ else: IS_ROCM_SYSTEM = False IS_CUDA_SYSTEM = False +logger = transformers_logging.get_logger(__name__) + def parse_flag_from_env(key, default=False): try: @@ -2497,7 +2500,49 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d return test_func_ref(*args, **kwargs) except Exception as err: - print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.") + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator + + +def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2): + """ + To decorate tests that download from the Hub. 
They can fail due to a + variety of network issues such as timeouts, connection resets, etc. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*, defaults to 2): + If provided, will wait that number of seconds before retrying the test. + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + # We catch all exceptions related to network issues from requests + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.ReadTimeout, + requests.exceptions.HTTPError, + requests.exceptions.RequestException, + ) as err: + logger.error( + f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specified Hub repository." + ) if wait_before_retry is not None: time.sleep(wait_before_retry) retry_count += 1 diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 80e7bd1447..8de8b0584c 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -74,6 +74,7 @@ from transformers.models.auto.modeling_auto import ( ) from transformers.testing_utils import ( CaptureLogger, + hub_retry, is_flaky, require_accelerate, require_bitsandbytes, @@ -214,6 +215,16 @@ class ModelTesterMixin: _is_composite = False model_split_percents = [0.5, 0.7, 0.9] + # Note: for all mixins that utilize the Hub in some way, we should ensure that + # they contain the `hub_retry` decorator in case of failures. 
+ def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + for attr_name in dir(cls): + if attr_name.startswith("test_"): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, hub_retry()(attr)) + @property def all_generative_model_classes(self): return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate()) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 72434a1922..7d8906fa59 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -51,6 +51,7 @@ from transformers.testing_utils import ( LoggingLevel, TemporaryHubRepo, TestCasePlus, + hub_retry, is_staging_test, require_accelerate, require_flax, @@ -327,6 +328,18 @@ class ModelUtilsTest(TestCasePlus): torch.set_default_dtype(self.old_dtype) super().tearDown() + def test_hub_retry(self): + @hub_retry(max_attempts=2) + def test_func(): + # First attempt will fail with a connection error + if not hasattr(test_func, "attempt"): + test_func.attempt = 1 + raise requests.exceptions.ConnectionError("Connection failed") + # Second attempt will succeed + return True + + self.assertTrue(test_func()) + @slow def test_model_from_pretrained(self): model_name = "google-bert/bert-base-uncased"