Make dummy inputs a local variable in TFPreTrainedModel.
This commit is contained in:
@@ -51,7 +51,6 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
config_class = None
|
config_class = None
|
||||||
pretrained_model_archive_map = {}
|
pretrained_model_archive_map = {}
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||||
@@ -266,14 +265,15 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
||||||
|
|
||||||
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
|
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
||||||
|
ret = model(dummy_inputs, training=False) # build the network with dummy inputs
|
||||||
|
|
||||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||||
model.load_weights(resolved_archive_file, by_name=True)
|
model.load_weights(resolved_archive_file, by_name=True)
|
||||||
|
|
||||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
ret = model(dummy_inputs, training=False) # Make sure restore ops are run
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user