num_parameters helper

This commit is contained in:
Julien Chaumond
2020-01-10 17:40:02 +00:00
parent 331065e62d
commit 84c0aa1868
4 changed files with 35 additions and 2 deletions

View File

@@ -20,6 +20,7 @@ import logging
import os
import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format
@@ -31,7 +32,22 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
logger = logging.getLogger(__name__)
class TFPreTrainedModel(tf.keras.Model):
class TFModelUtils:
"""
A few utilities for `tf.keras.Model`s, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get number of (optionally, trainable) parameters in the model.
"""
if only_trainable:
return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
else:
return self.count_params()
class TFPreTrainedModel(tf.keras.Model, TFModelUtils):
r""" Base class for all TF models.
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models