From 124409d0754f5b84ff47daea51dede1d3b37a8dd Mon Sep 17 00:00:00 2001 From: Filip Povolny Date: Tue, 5 Nov 2019 11:48:45 +0100 Subject: [PATCH] Make dummy inputs a property of TFPreTrainedModel. --- transformers/modeling_tf_utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index 110a590f55..33cfdc503d 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model): pretrained_model_archive_map = {} base_model_prefix = "" + @property + def dummy_inputs(self): + """ Dummy inputs to build the network. + + Returns: + tf.Tensor with dummy inputs + """ + return tf.constant(DUMMY_INPUTS) + def __init__(self, config, *inputs, **kwargs): super(TFPreTrainedModel, self).__init__(*inputs, **kwargs) if not isinstance(config, PretrainedConfig): @@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model): # Load from a PyTorch checkpoint return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file) - dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network - ret = model(dummy_inputs, training=False) # build the network with dummy inputs + ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file) # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 model.load_weights(resolved_archive_file, by_name=True) - ret = model(dummy_inputs, training=False) # Make sure restore ops are run + ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run return model