Make dummy inputs a local variable in TFPreTrainedModel.
This commit is contained in:
@@ -51,7 +51,6 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
config_class = None
|
||||
pretrained_model_archive_map = {}
|
||||
base_model_prefix = ""
|
||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
@@ -266,14 +265,15 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
# Load from a PyTorch checkpoint
|
||||
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
||||
|
||||
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
|
||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
||||
ret = model(dummy_inputs, training=False) # build the network with dummy inputs
|
||||
|
||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
model.load_weights(resolved_archive_file, by_name=True)
|
||||
|
||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||
ret = model(dummy_inputs, training=False) # Make sure restore ops are run
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user