Use class decorator instead of superclass

When supplied by Keras deserialization, the config parameter to initializers
will be a dict. So intercept it and convert to PretrainedConfig object (and
store in instance attribute for get_config to get at it) before passing to the
actual initializer. To accomplish this, and repeat as little code as possible,
use a class decorator on TF*MainLayer classes.
This commit is contained in:
Gunnlaugur Thor Briem
2020-03-03 22:31:38 +00:00
parent b8da16f390
commit 0c716ede8c
14 changed files with 94 additions and 46 deletions

View File

@@ -22,7 +22,6 @@ import unittest
from importlib import import_module
from transformers import is_tf_available, is_torch_available
from transformers.modeling_tf_utils import TFMainLayer
from .utils import _tf_gpu_memory_limit, require_tf
@@ -90,6 +89,7 @@ class TFModelTesterMixin:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
after_outputs = model(inputs_dict)
self.assert_outputs_same(after_outputs, outputs)
def test_keras_save_load(self):
@@ -100,10 +100,14 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes
for module in (import_module(model_class.__module__),)
for module_member_name in dir(module)
if module_member_name.endswith("MainLayer")
for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type) and TFMainLayer in module_member.__bases__
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
)
for main_layer_class in tf_main_layer_classes:
if main_layer_class.__name__ == "TFT5MainLayer":
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
continue
main_layer = main_layer_class(config)
symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
@@ -125,6 +129,7 @@ class TFModelTesterMixin:
# Make sure we don't have nans
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
self.assertEqual(out_1.shape, out_2.shape)
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))