Add retry hf hub decorator (#35213)
* Add retry torch decorator * New approach * Empty commit * Empty commit * Style * Use logger.error * Add a test * Update src/transformers/testing_utils.py Co-authored-by: Lucain <lucainp@gmail.com> * Fix err * Update tests/utils/test_modeling_utils.py --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -43,6 +43,7 @@ from unittest import mock
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import huggingface_hub.utils
|
import huggingface_hub.utils
|
||||||
|
import requests
|
||||||
import urllib3
|
import urllib3
|
||||||
from huggingface_hub import delete_repo
|
from huggingface_hub import delete_repo
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -200,6 +201,8 @@ else:
|
|||||||
IS_ROCM_SYSTEM = False
|
IS_ROCM_SYSTEM = False
|
||||||
IS_CUDA_SYSTEM = False
|
IS_CUDA_SYSTEM = False
|
||||||
|
|
||||||
|
logger = transformers_logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def parse_flag_from_env(key, default=False):
|
def parse_flag_from_env(key, default=False):
|
||||||
try:
|
try:
|
||||||
@@ -2497,7 +2500,49 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
|
|||||||
return test_func_ref(*args, **kwargs)
|
return test_func_ref(*args, **kwargs)
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
|
logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
|
||||||
|
if wait_before_retry is not None:
|
||||||
|
time.sleep(wait_before_retry)
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
return test_func_ref(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
|
||||||
|
"""
|
||||||
|
To decorate tests that download from the Hub. They can fail due to a
|
||||||
|
variety of network issues such as timeouts, connection resets, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_attempts (`int`, *optional*, defaults to 5):
|
||||||
|
The maximum number of attempts to retry the flaky test.
|
||||||
|
wait_before_retry (`float`, *optional*, defaults to 2):
|
||||||
|
If provided, will wait that number of seconds before retrying the test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(test_func_ref):
|
||||||
|
@functools.wraps(test_func_ref)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
retry_count = 1
|
||||||
|
|
||||||
|
while retry_count < max_attempts:
|
||||||
|
try:
|
||||||
|
return test_func_ref(*args, **kwargs)
|
||||||
|
# We catch all exceptions related to network issues from requests
|
||||||
|
except (
|
||||||
|
requests.exceptions.ConnectionError,
|
||||||
|
requests.exceptions.Timeout,
|
||||||
|
requests.exceptions.ReadTimeout,
|
||||||
|
requests.exceptions.HTTPError,
|
||||||
|
requests.exceptions.RequestException,
|
||||||
|
) as err:
|
||||||
|
logger.error(
|
||||||
|
f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specied Hub repository."
|
||||||
|
)
|
||||||
if wait_before_retry is not None:
|
if wait_before_retry is not None:
|
||||||
time.sleep(wait_before_retry)
|
time.sleep(wait_before_retry)
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ from transformers.models.auto.modeling_auto import (
|
|||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
|
hub_retry,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
@@ -214,6 +215,16 @@ class ModelTesterMixin:
|
|||||||
_is_composite = False
|
_is_composite = False
|
||||||
model_split_percents = [0.5, 0.7, 0.9]
|
model_split_percents = [0.5, 0.7, 0.9]
|
||||||
|
|
||||||
|
# Note: for all mixins that utilize the Hub in some way, we should ensure that
|
||||||
|
# they contain the `hub_retry` decorator in case of failures.
|
||||||
|
def __init_subclass__(cls, **kwargs):
    """
    Wrap every callable `test_*` attribute of a newly created subclass with the `hub_retry` decorator.

    `hub_retry` is a decorator factory: it must be called first (`hub_retry()`) to obtain the actual decorator,
    which is then applied to the test. Passing the test straight to `hub_retry` would bind it to the factory's
    first parameter (`max_attempts`) and replace the test with the un-applied decorator, so the test body would
    never be executed.
    """
    super().__init_subclass__(**kwargs)
    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            continue
        attr = getattr(cls, attr_name)
        if callable(attr):
            # Call the factory with its defaults, then decorate the test with the result.
            setattr(cls, attr_name, hub_retry()(attr))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_generative_model_classes(self):
|
def all_generative_model_classes(self):
|
||||||
return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())
|
return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ from transformers.testing_utils import (
|
|||||||
LoggingLevel,
|
LoggingLevel,
|
||||||
TemporaryHubRepo,
|
TemporaryHubRepo,
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
|
hub_retry,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_flax,
|
require_flax,
|
||||||
@@ -327,6 +328,18 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
torch.set_default_dtype(self.old_dtype)
|
torch.set_default_dtype(self.old_dtype)
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
||||||
|
def test_hub_retry(self):
    """Check that `hub_retry` calls the decorated function again after a connection error."""
    calls = []

    @hub_retry(max_attempts=2)
    def flaky_download():
        calls.append(None)
        # Only the very first call simulates a network failure
        if len(calls) == 1:
            raise requests.exceptions.ConnectionError("Connection failed")
        # Every later call succeeds
        return True

    self.assertTrue(flaky_download())
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
model_name = "google-bert/bert-base-uncased"
|
model_name = "google-bert/bert-base-uncased"
|
||||||
|
|||||||
Reference in New Issue
Block a user