Use class decorator instead of superclass
When supplied by Keras deserialization, the config parameter to initializers will be a dict. So intercept it and convert to PretrainedConfig object (and store in instance attribute for get_config to get at it) before passing to the actual initializer. To accomplish this, and repeat as little code as possible, use a class decorator on TF*MainLayer classes.
This commit is contained in:
@@ -22,7 +22,6 @@ import unittest
|
||||
from importlib import import_module
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.modeling_tf_utils import TFMainLayer
|
||||
|
||||
from .utils import _tf_gpu_memory_limit, require_tf
|
||||
|
||||
@@ -90,6 +89,7 @@ class TFModelTesterMixin:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
after_outputs = model(inputs_dict)
|
||||
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
def test_keras_save_load(self):
|
||||
@@ -100,10 +100,14 @@ class TFModelTesterMixin:
|
||||
for model_class in self.all_model_classes
|
||||
for module in (import_module(model_class.__module__),)
|
||||
for module_member_name in dir(module)
|
||||
if module_member_name.endswith("MainLayer")
|
||||
for module_member in (getattr(module, module_member_name),)
|
||||
if isinstance(module_member, type) and TFMainLayer in module_member.__bases__
|
||||
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
|
||||
)
|
||||
for main_layer_class in tf_main_layer_classes:
|
||||
if main_layer_class.__name__ == "TFT5MainLayer":
|
||||
# Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly
|
||||
continue
|
||||
main_layer = main_layer_class(config)
|
||||
symbolic_inputs = {
|
||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||
@@ -125,6 +129,7 @@ class TFModelTesterMixin:
|
||||
# Make sure we don't have nans
|
||||
out_1 = after_outputs[0].numpy()
|
||||
out_2 = outputs[0].numpy()
|
||||
self.assertEqual(out_1.shape, out_2.shape)
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
|
||||
Reference in New Issue
Block a user