Add tests that TF 2.0 model can be integrated with other Keras modules
This commit is contained in:
@@ -22,6 +22,7 @@ import random
|
|||||||
import shutil
|
import shutil
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import sys
|
import sys
|
||||||
@@ -36,6 +37,20 @@ if is_tf_available():
|
|||||||
else:
|
else:
|
||||||
pytestmark = pytest.mark.skip("Require TensorFlow")
|
pytestmark = pytest.mark.skip("Require TensorFlow")
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
import cPickle as pickle
|
||||||
|
|
||||||
|
class TemporaryDirectory(object):
|
||||||
|
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||||
|
def __enter__(self):
|
||||||
|
self.name = tempfile.mkdtemp()
|
||||||
|
return self.name
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
shutil.rmtree(self.name)
|
||||||
|
else:
|
||||||
|
import pickle
|
||||||
|
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||||
|
unicode = str
|
||||||
|
|
||||||
def _config_zero_init(config):
|
def _config_zero_init(config):
|
||||||
configs_no_init = copy.deepcopy(config)
|
configs_no_init = copy.deepcopy(config)
|
||||||
@@ -66,13 +81,25 @@ class TFCommonTestCases:
|
|||||||
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
|
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
|
||||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||||
|
|
||||||
|
def test_save_load(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
outputs = model(inputs_dict)
|
||||||
|
|
||||||
|
with TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
|
after_outputs = model(inputs_dict)
|
||||||
|
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def test_pt_tf_model_equivalence(self):
|
def test_pt_tf_model_equivalence(self):
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
return
|
return
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -99,6 +126,34 @@ class TFCommonTestCases:
|
|||||||
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
|
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
|
||||||
self.assertLessEqual(max_diff, 2e-2)
|
self.assertLessEqual(max_diff, 2e-2)
|
||||||
|
|
||||||
|
def test_compile_tf_model(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
|
||||||
|
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||||
|
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
|
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# Prepare our model
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
# Let's load it from the disk to be sure we can use pretrained weights
|
||||||
|
with TemporaryDirectory() as tmpdirname:
|
||||||
|
outputs = model(inputs_dict) # build the model
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
outputs_dict = model(input_ids)
|
||||||
|
hidden_states = outputs_dict[0]
|
||||||
|
|
||||||
|
# Add a dense layer on top to test intetgration with other keras modules
|
||||||
|
outputs = tf.keras.layers.Dense(2, activation='softmax', name='outputs')(hidden_states)
|
||||||
|
|
||||||
|
# Compile extended model
|
||||||
|
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||||
|
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||||
|
|
||||||
def test_keyword_and_dict_args(self):
|
def test_keyword_and_dict_args(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user