adding more tests on TF and pytorch serialization - updating configuration for better serialization
This commit is contained in:
@@ -25,8 +25,6 @@ import numpy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
|
||||
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
|
||||
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
|
||||
|
||||
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
raise e
|
||||
|
||||
if tf_inputs is None:
|
||||
tf_inputs = tf.constant(DUMMY_INPUTS)
|
||||
tf_inputs = tf_model.dummy_inputs
|
||||
|
||||
if tf_inputs is not None:
|
||||
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||
|
||||
Reference in New Issue
Block a user