diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8de8b0584c..bb6f14ce3d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -223,7 +223,7 @@ class ModelTesterMixin: if attr_name.startswith("test_"): attr = getattr(cls, attr_name) if callable(attr): - setattr(cls, attr_name, hub_retry(attr)) + setattr(cls, attr_name, hub_retry()(attr)) @property def all_generative_model_classes(self):