adding more tests on TF and pytorch serialization - updating configuration for better serialization
This commit is contained in:
@@ -25,9 +25,11 @@ import tensorflow as tf
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
|
||||
class TFPreTrainedModel(tf.keras.Model):
|
||||
r""" Base class for all TF models.
|
||||
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
"""
|
||||
config_class = None
|
||||
pretrained_model_archive_map = {}
|
||||
load_pt_weights = lambda model, config, path: None
|
||||
base_model_prefix = ""
|
||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
|
||||
if from_pt:
|
||||
# Load from a PyTorch checkpoint
|
||||
return cls.load_pt_weights(model, resolved_archive_file)
|
||||
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
||||
|
||||
inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||
ret = model(inputs, training=False) # build the network with dummy inputs
|
||||
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
|
||||
|
||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
model.load_weights(resolved_archive_file, by_name=True)
|
||||
|
||||
ret = model(inputs, training=False) # Make sure restore ops are run
|
||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user