Add retry hf hub decorator (#35213)

* Add retry torch decorator

* New approach

* Empty commit

* Empty commit

* Style

* Use logger.error

* Add a test

* Update src/transformers/testing_utils.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Fix err

* Update tests/utils/test_modeling_utils.py

---------

Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2025-02-25 14:53:11 -05:00
committed by GitHub
parent 9ebfda3263
commit 41925e4213
3 changed files with 70 additions and 1 deletions

View File

@@ -43,6 +43,7 @@ from unittest import mock
from unittest.mock import patch from unittest.mock import patch
import huggingface_hub.utils import huggingface_hub.utils
import requests
import urllib3 import urllib3
from huggingface_hub import delete_repo from huggingface_hub import delete_repo
from packaging import version from packaging import version
@@ -200,6 +201,8 @@ else:
IS_ROCM_SYSTEM = False IS_ROCM_SYSTEM = False
IS_CUDA_SYSTEM = False IS_CUDA_SYSTEM = False
logger = transformers_logging.get_logger(__name__)
def parse_flag_from_env(key, default=False): def parse_flag_from_env(key, default=False):
try: try:
@@ -2497,7 +2500,49 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
return test_func_ref(*args, **kwargs) return test_func_ref(*args, **kwargs)
except Exception as err: except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
return test_func_ref(*args, **kwargs)
return wrapper
return decorator
def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
"""
To decorate tests that download from the Hub. They can fail due to a
variety of network issues such as timeouts, connection resets, etc.
Args:
max_attempts (`int`, *optional*, defaults to 5):
The maximum number of attempts to retry the flaky test.
wait_before_retry (`float`, *optional*, defaults to 2):
If provided, will wait that number of seconds before retrying the test.
"""
def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def wrapper(*args, **kwargs):
retry_count = 1
while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)
# We catch all exceptions related to network issues from requests
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.ReadTimeout,
requests.exceptions.HTTPError,
requests.exceptions.RequestException,
) as err:
logger.error(
f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specied Hub repository."
)
if wait_before_retry is not None: if wait_before_retry is not None:
time.sleep(wait_before_retry) time.sleep(wait_before_retry)
retry_count += 1 retry_count += 1

View File

@@ -74,6 +74,7 @@ from transformers.models.auto.modeling_auto import (
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
CaptureLogger, CaptureLogger,
hub_retry,
is_flaky, is_flaky,
require_accelerate, require_accelerate,
require_bitsandbytes, require_bitsandbytes,
@@ -214,6 +215,16 @@ class ModelTesterMixin:
_is_composite = False _is_composite = False
model_split_percents = [0.5, 0.7, 0.9] model_split_percents = [0.5, 0.7, 0.9]
# Note: for all mixins that utilize the Hub in some way, we should ensure that
# they contain the `hub_retry` decorator in case of failures.
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
for attr_name in dir(cls):
if attr_name.startswith("test_"):
attr = getattr(cls, attr_name)
if callable(attr):
setattr(cls, attr_name, hub_retry(attr))
@property @property
def all_generative_model_classes(self): def all_generative_model_classes(self):
return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate()) return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())

View File

@@ -51,6 +51,7 @@ from transformers.testing_utils import (
LoggingLevel, LoggingLevel,
TemporaryHubRepo, TemporaryHubRepo,
TestCasePlus, TestCasePlus,
hub_retry,
is_staging_test, is_staging_test,
require_accelerate, require_accelerate,
require_flax, require_flax,
@@ -327,6 +328,18 @@ class ModelUtilsTest(TestCasePlus):
torch.set_default_dtype(self.old_dtype) torch.set_default_dtype(self.old_dtype)
super().tearDown() super().tearDown()
def test_hub_retry(self):
@hub_retry(max_attempts=2)
def test_func():
# First attempt will fail with a connection error
if not hasattr(test_func, "attempt"):
test_func.attempt = 1
raise requests.exceptions.ConnectionError("Connection failed")
# Second attempt will succeed
return True
self.assertTrue(test_func())
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
model_name = "google-bert/bert-base-uncased" model_name = "google-bert/bert-base-uncased"