From 8df7dfd2a723465b0cca9f5e808a75e074482d02 Mon Sep 17 00:00:00 2001 From: Filip Povolny Date: Tue, 5 Nov 2019 11:09:16 +0100 Subject: [PATCH] Make dummy inputs a local variable in TFPreTrainedModel. --- transformers/modeling_tf_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index a96e2765fd..110a590f55 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -51,7 +51,6 @@ class TFPreTrainedModel(tf.keras.Model): config_class = None pretrained_model_archive_map = {} base_model_prefix = "" - dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network def __init__(self, config, *inputs, **kwargs): super(TFPreTrainedModel, self).__init__(*inputs, **kwargs) @@ -266,14 +265,15 @@ class TFPreTrainedModel(tf.keras.Model): # Load from a PyTorch checkpoint return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file) - ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs + dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network + ret = model(dummy_inputs, training=False) # build the network with dummy inputs assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file) # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 model.load_weights(resolved_archive_file, by_name=True) - ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run + ret = model(dummy_inputs, training=False) # Make sure restore ops are run return model