diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 3b112704cc..c363c89e3d 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -17,6 +17,7 @@ import logging from collections import OrderedDict +from importlib import import_module from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig @@ -100,6 +101,20 @@ class AutoConfig: "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method." ) + @classmethod + def config_class_for_model_class(cls, model_class): + module = import_module(model_class.__module__) + return next( + ( + module_attribute + for module_attribute_name in dir(module) + if module_attribute_name.endswith("Config") + for module_attribute in (getattr(module, module_attribute_name),) + if issubclass(module_attribute, PretrainedConfig) + ), + None, + ) + @classmethod def for_model(cls, model_type, *args, **kwargs): for pattern, config_class in CONFIG_MAPPING.items(): diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index b27f0eb8fa..698a1523d6 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import ACT2FN, TFBertSelfAttention -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -478,9 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): return hidden_states -class TFAlbertMainLayer(TFMainLayer): +@keras_serializable +class TFAlbertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFAlbertEmbeddings(config, name="embeddings") diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 623319bcf1..30bfdb4347 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_bert import BertConfig from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -471,9 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer): return seq_relationship_score -class TFBertMainLayer(TFMainLayer): +@keras_serializable +class TFBertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFBertEmbeddings(config, name="embeddings") diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 6e3d0b1b5a..d801b7cb50 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_ctrl import CTRLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -164,9 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer): return outputs -class TFCTRLMainLayer(TFMainLayer): +@keras_serializable +class TFCTRLMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.output_past = config.output_past diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index 1570dabc07..c58243965a 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_distilbert import DistilBertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -397,9 +397,10 @@ class TFTransformer(tf.keras.layers.Layer): return outputs # last-layer hidden state, (all hidden states), (all attentions) -class TFDistilBertMainLayer(TFMainLayer): +@keras_serializable +class TFDistilBertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index d1f1cc9147..7a9a72bd68 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -25,11 +25,11 @@ from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -197,9 +197,10 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, present, (attentions) -class TFGPT2MainLayer(TFMainLayer): +@keras_serializable +class TFGPT2MainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) + super().__init__(*inputs, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.num_hidden_layers = config.n_layer diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 99b6533a0c..954091a5bd 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -25,11 +25,11 @@ from .configuration_openai import OpenAIGPTConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -198,9 +198,10 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, (attentions) -class TFOpenAIGPTMainLayer(TFMainLayer): +@keras_serializable +class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) + super().__init__(*inputs, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.num_hidden_layers = config.n_layer diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py index 31fb43f1cc..3347d7e041 100644 --- a/src/transformers/modeling_tf_roberta.py +++ b/src/transformers/modeling_tf_roberta.py @@ -20,10 +20,11 @@ import logging import tensorflow as tf +from . import PretrainedConfig from .configuration_roberta import RobertaConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 974f744dca..b0f56edcfa 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -25,7 +25,7 @@ import tensorflow as tf from .configuration_t5 import T5Config from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -359,9 +359,10 @@ class TFT5Block(tf.keras.layers.Layer): # The full model without a specific pretrained or finetuning head is # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" #################################################### -class TFT5MainLayer(TFMainLayer): +@keras_serializable +class TFT5MainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.is_decoder = config.is_decoder @@ -383,14 +384,21 @@ class TFT5MainLayer(TFMainLayer): def call( self, - hidden_states, + inputs, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, training=False, ): - + if isinstance(inputs, (tuple, list)): + hidden_states = inputs[0] + assert len(inputs) <= 1, "Too many inputs." + elif isinstance(inputs, dict): + hidden_states = inputs["hidden_states"] + assert len(inputs) <= 1, "Too many inputs." + else: + hidden_states = inputs batch_size, seq_length = shape_list(hidden_states)[:2] if attention_mask is None: attention_mask = tf.fill((batch_size, seq_length), 1) diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 1a65cce874..612e7a711d 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_transfo_xl import TransfoXLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -378,9 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer): return embed -class TFTransfoXLMainLayer(TFMainLayer): +@keras_serializable +class TFTransfoXLMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e2e12fb5b1..c024aa8704 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,21 +47,31 @@ class TFModelUtilsMixin: return self.count_params() -class TFMainLayer(tf.keras.layers.Layer): - """ - A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization. - """ +def keras_serializable(cls): + initializer = cls.__init__ - def __init__(self, config, **kwargs): - super().__init__(**kwargs) + def wrapped_init(self, config, *args, **kwargs): if isinstance(config, dict): - config = PretrainedConfig.from_dict(config) + from transformers import AutoConfig + + config = AutoConfig.config_class_for_model_class(cls).from_dict(config) + initializer(self, config, *args, **kwargs) self._transformers_config = config - def get_config(self): - cfg = super().get_config() - cfg["config"] = self._transformers_config.to_dict() - return cfg + cls.__init__ = wrapped_init + + if not hasattr(cls, "get_config"): + raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses") + if hasattr(cls.get_config, "_is_default"): + + def get_config(self): + cfg = super(cls, self).get_config() + cfg["config"] = self._transformers_config.to_dict() + return cfg + + cls.get_config = get_config + + return tf.keras.utils.register_keras_serializable()(cls) class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 2e8f2fda64..2b7f49d7f0 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -26,11 +26,11 @@ import tensorflow as tf from .configuration_xlm import XLMConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -203,9 +203,10 @@ class TFTransformerFFN(tf.keras.layers.Layer): return x -class TFXLMMainLayer(TFMainLayer): +@keras_serializable +class TFXLMMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 0050546fb0..c899628d88 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -25,11 +25,11 @@ import tensorflow as tf from .configuration_xlnet import XLNetConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -349,9 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer): return hidden_states -class TFXLNetMainLayer(TFMainLayer): +@keras_serializable +class TFXLNetMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.output_past = config.output_past diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 6e9c967ba6..6b4fb24f5b 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -22,7 +22,6 @@ import unittest from importlib import import_module from transformers import is_tf_available, is_torch_available -from transformers.modeling_tf_utils import TFMainLayer from .utils import _tf_gpu_memory_limit, require_tf @@ -90,6 +89,7 @@ class TFModelTesterMixin: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname) after_outputs = model(inputs_dict) + self.assert_outputs_same(after_outputs, outputs) def test_keras_save_load(self): @@ -100,10 +100,14 @@ class TFModelTesterMixin: for model_class in self.all_model_classes for module in (import_module(model_class.__module__),) for module_member_name in dir(module) + if module_member_name.endswith("MainLayer") for module_member in (getattr(module, module_member_name),) - if isinstance(module_member, type) and TFMainLayer in module_member.__bases__ + if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__ ) for main_layer_class in tf_main_layer_classes: + if main_layer_class.__name__ == "TFT5MainLayer": + # Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly + continue main_layer = main_layer_class(config) symbolic_inputs = { name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() @@ -125,6 +129,7 @@ class TFModelTesterMixin: # Make sure we don't have nans out_1 = after_outputs[0].numpy() out_2 = outputs[0].numpy() + self.assertEqual(out_1.shape, out_2.shape) out_1 = out_1[~np.isnan(out_1)] out_2 = out_2[~np.isnan(out_2)] max_diff = np.amax(np.abs(out_1 - out_2))