Add retry hf hub decorator (#35213)
* Add retry torch decorator * New approach * Empty commit * Empty commit * Style * Use logger.error * Add a test * Update src/transformers/testing_utils.py Co-authored-by: Lucain <lucainp@gmail.com> * Fix err * Update tests/utils/test_modeling_utils.py --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -74,6 +74,7 @@ from transformers.models.auto.modeling_auto import (
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
hub_retry,
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
@@ -214,6 +215,16 @@ class ModelTesterMixin:
|
||||
_is_composite = False
|
||||
model_split_percents = [0.5, 0.7, 0.9]
|
||||
|
||||
# Note: for all mixins that utilize the Hub in some way, we should ensure that
|
||||
# they contain the `hub_retry` decorator in case of failures.
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
for attr_name in dir(cls):
|
||||
if attr_name.startswith("test_"):
|
||||
attr = getattr(cls, attr_name)
|
||||
if callable(attr):
|
||||
setattr(cls, attr_name, hub_retry(attr))
|
||||
|
||||
@property
|
||||
def all_generative_model_classes(self):
|
||||
return tuple(model_class for model_class in self.all_model_classes if model_class.can_generate())
|
||||
|
||||
Reference in New Issue
Block a user