Add retry hf hub decorator (#35213)
* Add retry torch decorator * New approach * Empty commit * Empty commit * Style * Use logger.error * Add a test * Update src/transformers/testing_utils.py Co-authored-by: Lucain <lucainp@gmail.com> * Fix err * Update tests/utils/test_modeling_utils.py --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -51,6 +51,7 @@ from transformers.testing_utils import (
|
||||
LoggingLevel,
|
||||
TemporaryHubRepo,
|
||||
TestCasePlus,
|
||||
hub_retry,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
require_flax,
|
||||
@@ -327,6 +328,18 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.set_default_dtype(self.old_dtype)
|
||||
super().tearDown()
|
||||
|
||||
def test_hub_retry(self):
|
||||
@hub_retry(max_attempts=2)
|
||||
def test_func():
|
||||
# First attempt will fail with a connection error
|
||||
if not hasattr(test_func, "attempt"):
|
||||
test_func.attempt = 1
|
||||
raise requests.exceptions.ConnectionError("Connection failed")
|
||||
# Second attempt will succeed
|
||||
return True
|
||||
|
||||
self.assertTrue(test_func())
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "google-bert/bert-base-uncased"
|
||||
|
||||
Reference in New Issue
Block a user