Add retry hf hub decorator (#35213)

* Add retry torch decorator

* New approach

* Empty commit

* Empty commit

* Style

* Use logger.error

* Add a test

* Update src/transformers/testing_utils.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Fix err

* Update tests/utils/test_modeling_utils.py

---------

Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2025-02-25 14:53:11 -05:00
committed by GitHub
parent 9ebfda3263
commit 41925e4213
3 changed files with 70 additions and 1 deletions

View File

@@ -74,6 +74,7 @@ from transformers.models.auto.modeling_auto import (
)
from transformers.testing_utils import (
CaptureLogger,
hub_retry,
is_flaky,
require_accelerate,
require_bitsandbytes,
@@ -214,6 +215,16 @@ class ModelTesterMixin:
_is_composite = False
model_split_percents = [0.5, 0.7, 0.9]
# Note: for all mixins that utilize the Hub in some way, we should ensure that
# they contain the `hub_retry` decorator in case of failures.
def __init_subclass__(cls, **kwargs):
    """Wrap every `test_*` callable of a newly created subclass with `hub_retry`.

    Python calls this hook once for each class that subclasses this one. Every
    name reported by `dir(cls)` that starts with `test_` and resolves to a
    callable is replaced on `cls` by the result of decorating it with
    `hub_retry()`.

    Args:
        **kwargs: Class-creation keyword arguments, forwarded unchanged to the
            parent `__init_subclass__`.
    """
    super().__init_subclass__(**kwargs)
    # NOTE(review): `dir(cls)` also lists inherited attributes, so a test that was
    # already wrapped on a parent class is wrapped again for each deeper subclass.
    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            continue
        attr = getattr(cls, attr_name)
        if callable(attr):
            # `hub_retry` is used as a decorator factory elsewhere in this change
            # (`@hub_retry(max_attempts=2)`), so it must be called first to obtain
            # the actual decorator. Passing the test function directly would hand
            # it to the factory as a configuration argument and install the bare
            # decorator on the class instead of a wrapped test.
            setattr(cls, attr_name, hub_retry()(attr))
@property
def all_generative_model_classes(self):
    """Return, as a tuple, the entries of `self.all_model_classes` whose `can_generate()` is truthy.

    The relative order of `self.all_model_classes` is preserved.
    """
    generative = [candidate for candidate in self.all_model_classes if candidate.can_generate()]
    return tuple(generative)

View File

@@ -51,6 +51,7 @@ from transformers.testing_utils import (
LoggingLevel,
TemporaryHubRepo,
TestCasePlus,
hub_retry,
is_staging_test,
require_accelerate,
require_flax,
@@ -327,6 +328,18 @@ class ModelUtilsTest(TestCasePlus):
torch.set_default_dtype(self.old_dtype)
super().tearDown()
def test_hub_retry(self):
    """Check that a function decorated with `hub_retry(max_attempts=2)` succeeds after one connection error."""
    # Records one entry per invocation of the decorated function.
    calls = []

    @hub_retry(max_attempts=2)
    def test_func():
        calls.append(None)
        if len(calls) == 1:
            # Only the very first invocation simulates a network failure.
            raise requests.exceptions.ConnectionError("Connection failed")
        # Every later invocation reports success.
        return True

    outcome = test_func()
    self.assertTrue(outcome)
@slow
def test_model_from_pretrained(self):
model_name = "google-bert/bert-base-uncased"