Merge pull request #1509 from julian-pani/patch-3
remove leftover usage of DUMMY_INPUTS
This commit is contained in:
@@ -198,7 +198,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
|
|||||||
tf_model = tf_model_class(pt_model.config)
|
tf_model = tf_model_class(pt_model.config)
|
||||||
|
|
||||||
if tf_inputs is None:
|
if tf_inputs is None:
|
||||||
tf_inputs = tf.constant(DUMMY_INPUTS)
|
tf_inputs = tf_model.dummy_inputs
|
||||||
|
|
||||||
if tf_inputs is not None:
|
if tf_inputs is not None:
|
||||||
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -118,7 +119,7 @@ class TFCommonTestCases:
|
|||||||
tf_model = model_class(config)
|
tf_model = model_class(config)
|
||||||
pt_model = pt_model_class(config)
|
pt_model = pt_model_class(config)
|
||||||
|
|
||||||
# Check we can load pt model in tf and vice-versa (architecture similar)
|
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
|
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
|
||||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||||
|
|
||||||
@@ -132,6 +133,26 @@ class TFCommonTestCases:
|
|||||||
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
|
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
|
||||||
self.assertLessEqual(max_diff, 2e-2)
|
self.assertLessEqual(max_diff, 2e-2)
|
||||||
|
|
||||||
|
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||||
|
with TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_checkpoint_path = os.path.join(tmpdirname, 'pt_model.bin')
|
||||||
|
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||||
|
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||||
|
|
||||||
|
tf_checkpoint_path = os.path.join(tmpdirname, 'tf_model.h5')
|
||||||
|
tf_model.save_weights(tf_checkpoint_path)
|
||||||
|
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||||
|
|
||||||
|
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||||
|
pt_model.eval()
|
||||||
|
pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
|
||||||
|
for name, key in inputs_dict.items())
|
||||||
|
with torch.no_grad():
|
||||||
|
pto = pt_model(**pt_inputs_dict)
|
||||||
|
tfo = tf_model(inputs_dict)
|
||||||
|
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
|
||||||
|
self.assertLessEqual(max_diff, 2e-2)
|
||||||
|
|
||||||
def test_compile_tf_model(self):
|
def test_compile_tf_model(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user