num_parameters helper
This commit is contained in:
@@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
|
||||
@@ -31,7 +32,22 @@ from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TFPreTrainedModel(tf.keras.Model):
|
||||
class TFModelUtils:
|
||||
"""
|
||||
A few utilities for `tf.keras.Model`s, to be used as a mixin.
|
||||
"""
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False) -> int:
|
||||
"""
|
||||
Get number of (optionally, trainable) parameters in the model.
|
||||
"""
|
||||
if only_trainable:
|
||||
return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
|
||||
else:
|
||||
return self.count_params()
|
||||
|
||||
|
||||
class TFPreTrainedModel(tf.keras.Model, TFModelUtils):
|
||||
r""" Base class for all TF models.
|
||||
|
||||
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
||||
|
||||
Reference in New Issue
Block a user