diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 3e47170e5d..dacf4cd752 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -32,7 +32,7 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) -class TFModelUtils: +class TFModelUtilsMixin: """ A few utilities for `tf.keras.Model`s, to be used as a mixin. """ @@ -47,7 +47,7 @@ class TFModelUtils: return self.count_params() -class TFPreTrainedModel(tf.keras.Model, TFModelUtils): +class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): r""" Base class for all TF models. :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f71c81b4a1..9543f3bdc0 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -53,7 +53,7 @@ except ImportError: return input -class ModuleUtils: +class ModuleUtilsMixin: """ A few utilities for torch.nn.Modules, to be used as a mixin. """ @@ -66,7 +66,7 @@ class ModuleUtils: return sum(p.numel() for p in params) -class PreTrainedModel(nn.Module, ModuleUtils): +class PreTrainedModel(nn.Module, ModuleUtilsMixin): r""" Base class for all models. :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f5f93c7f07..77fb7252be 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -592,7 +592,6 @@ class ModelTesterMixin: model(**inputs_dict) - global_rng = random.Random()