From 84c0aa18688581e34c31322a272b6beac6d00938 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 10 Jan 2020 17:40:02 +0000 Subject: [PATCH] num_parameters helper --- src/transformers/modeling_tf_utils.py | 18 +++++++++++++++++- src/transformers/modeling_utils.py | 15 ++++++++++++++- tests/test_modeling_auto.py | 2 ++ tests/test_modeling_tf_auto.py | 2 ++ 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index df91ad011e..3e47170e5d 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -20,6 +20,7 @@ import logging import os import h5py +import numpy as np import tensorflow as tf from tensorflow.python.keras.saving import hdf5_format @@ -31,7 +32,22 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) -class TFPreTrainedModel(tf.keras.Model): +class TFModelUtils: + """ + A few utilities for `tf.keras.Model`s, to be used as a mixin. + """ + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Get number of (optionally, trainable) parameters in the model. + """ + if only_trainable: + return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables)) + else: + return self.count_params() + + +class TFPreTrainedModel(tf.keras.Model, TFModelUtils): r""" Base class for all TF models. :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 115494f2f5..f71c81b4a1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -53,7 +53,20 @@ except ImportError: return input -class PreTrainedModel(nn.Module): +class ModuleUtils: + """ + A few utilities for torch.nn.Modules, to be used as a mixin. + """ + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Get number of (optionally, trainable) parameters in the module. + """ + params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() + return sum(p.numel() for p in params) + + +class PreTrainedModel(nn.Module, ModuleUtils): r""" Base class for all models. :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models diff --git a/tests/test_modeling_auto.py b/tests/test_modeling_auto.py index 1819415e6a..c91e3cd3ed 100644 --- a/tests/test_modeling_auto.py +++ b/tests/test_modeling_auto.py @@ -100,3 +100,5 @@ class AutoModelTest(unittest.TestCase): logging.basicConfig(level=logging.INFO) model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) self.assertIsInstance(model, BertForMaskedLM) + self.assertEqual(model.num_parameters(), 14830) + self.assertEqual(model.num_parameters(only_trainable=True), 14830) diff --git a/tests/test_modeling_tf_auto.py b/tests/test_modeling_tf_auto.py index 0b34ee4834..56d5f3efbe 100644 --- a/tests/test_modeling_tf_auto.py +++ b/tests/test_modeling_tf_auto.py @@ -99,3 +99,5 @@ class TFAutoModelTest(unittest.TestCase): logging.basicConfig(level=logging.INFO) model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) self.assertIsInstance(model, TFBertForMaskedLM) + self.assertEqual(model.num_parameters(), 14830) + self.assertEqual(model.num_parameters(only_trainable=True), 14830)