adding more tests on TF and pytorch serialization - updating configuration for better serialization

This commit is contained in:
thomwolf
2019-10-10 14:30:48 +02:00
parent bb04edb45b
commit da26bae61b
15 changed files with 90 additions and 148 deletions

View File

@@ -25,8 +25,6 @@ import numpy
logger = logging.getLogger(__name__)
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
raise e
if tf_inputs is None:
tf_inputs = tf.constant(DUMMY_INPUTS)
tf_inputs = tf_model.dummy_inputs
if tf_inputs is not None:
tfo = tf_model(tf_inputs, training=False) # Make sure model is built