From ba28170717fbce8bddae70d065846bded799f9f3 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Tue, 3 Mar 2020 14:00:30 +0000 Subject: [PATCH 01/12] Support keras JSON/HDF5 serialization of main layers Fixes #3101 --- src/transformers/modeling_tf_albert.py | 4 ++-- src/transformers/modeling_tf_bert.py | 6 +++--- src/transformers/modeling_tf_ctrl.py | 6 +++--- src/transformers/modeling_tf_distilbert.py | 6 +++--- src/transformers/modeling_tf_gpt2.py | 5 +++-- src/transformers/modeling_tf_openai.py | 3 ++- src/transformers/modeling_tf_t5.py | 6 +++--- src/transformers/modeling_tf_transfo_xl.py | 6 +++--- src/transformers/modeling_tf_utils.py | 17 +++++++++++++++++ src/transformers/modeling_tf_xlm.py | 13 ++++++++++--- src/transformers/modeling_tf_xlnet.py | 13 ++++++++++--- 11 files changed, 59 insertions(+), 26 deletions(-) diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 64c9dad06a..b27f0eb8fa 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import ACT2FN, TFBertSelfAttention -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -478,7 +478,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): return hidden_states -class TFAlbertMainLayer(tf.keras.layers.Layer): +class TFAlbertMainLayer(TFMainLayer): def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self.num_hidden_layers = config.num_hidden_layers diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 1904623581..623319bcf1 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_bert import BertConfig from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -471,9 +471,9 @@ class TFBertNSPHead(tf.keras.layers.Layer): return seq_relationship_score -class TFBertMainLayer(tf.keras.layers.Layer): +class TFBertMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFBertEmbeddings(config, name="embeddings") diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 335421979c..6e3d0b1b5a 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_ctrl import CTRLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list logger = logging.getLogger(__name__) @@ -164,9 +164,9 @@ class TFEncoderLayer(tf.keras.layers.Layer): return outputs -class TFCTRLMainLayer(tf.keras.layers.Layer): +class TFCTRLMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.output_past = config.output_past diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index 6f6eaa3be0..1570dabc07 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_distilbert import DistilBertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -397,9 +397,9 @@ class TFTransformer(tf.keras.layers.Layer): return outputs # last-layer hidden state, (all hidden states), (all attentions) -class TFDistilBertMainLayer(tf.keras.layers.Layer): +class TFDistilBertMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 7e9b102b6d..d1f1cc9147 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -25,6 +25,7 @@ from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, + TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, @@ -196,9 +197,9 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, present, (attentions) -class TFGPT2MainLayer(tf.keras.layers.Layer): +class TFGPT2MainLayer(TFMainLayer): def __init__(self, config, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) + super().__init__(config, *inputs, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.num_hidden_layers = config.n_layer diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index f04104db83..99b6533a0c 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -25,6 +25,7 @@ from .configuration_openai import OpenAIGPTConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, + TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, @@ -197,7 +198,7 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, (attentions) -class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): +class TFOpenAIGPTMainLayer(TFMainLayer): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index db62e784b1..974f744dca 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -25,7 +25,7 @@ import tensorflow as tf from .configuration_t5 import T5Config from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list logger = logging.getLogger(__name__) @@ -359,9 +359,9 @@ class TFT5Block(tf.keras.layers.Layer): # The full model without a specific pretrained or finetuning head is # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" #################################################### -class TFT5MainLayer(tf.keras.layers.Layer): +class TFT5MainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.is_decoder = config.is_decoder diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 098a4c9143..1a65cce874 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_transfo_xl import TransfoXLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -378,9 +378,9 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer): return embed -class TFTransfoXLMainLayer(tf.keras.layers.Layer): +class TFTransfoXLMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 43abdd9499..e2e12fb5b1 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,6 +47,23 @@ class TFModelUtilsMixin: return self.count_params() +class TFMainLayer(tf.keras.layers.Layer): + """ + A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + if isinstance(config, dict): + config = PretrainedConfig.from_dict(config) + self._transformers_config = config + + def get_config(self): + cfg = super().get_config() + cfg["config"] = self._transformers_config.to_dict() + return cfg + + class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): r""" Base class for all TF models. diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 6e94a7206e..2e8f2fda64 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -25,7 +25,14 @@ import tensorflow as tf from .configuration_xlm import XLMConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import ( + TFMainLayer, + TFPreTrainedModel, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + shape_list, +) logger = logging.getLogger(__name__) @@ -196,9 +203,9 @@ class TFTransformerFFN(tf.keras.layers.Layer): return x -class TFXLMMainLayer(tf.keras.layers.Layer): +class TFXLMMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 87ebe16858..0050546fb0 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -24,7 +24,14 @@ import tensorflow as tf from .configuration_xlnet import XLNetConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import ( + TFMainLayer, + TFPreTrainedModel, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + shape_list, +) logger = logging.getLogger(__name__) @@ -342,9 +349,9 @@ class TFXLNetLMHead(tf.keras.layers.Layer): return hidden_states -class TFXLNetMainLayer(tf.keras.layers.Layer): +class TFXLNetMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.output_past = config.output_past From b8da16f39074a80e742e98a6751f256351127174 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Tue, 3 Mar 2020 15:15:30 +0000 Subject: [PATCH 02/12] Add (failing) tests for Keras save/load --- tests/test_modeling_tf_common.py | 47 +++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 8cd53dfe19..6e9c967ba6 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -19,8 +19,10 @@ import os import random import tempfile import unittest +from importlib import import_module from transformers import is_tf_available, is_torch_available +from transformers.modeling_tf_utils import TFMainLayer from .utils import _tf_gpu_memory_limit, require_tf @@ -88,14 +90,45 @@ class TFModelTesterMixin: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname) after_outputs = model(inputs_dict) + self.assert_outputs_same(after_outputs, outputs) - # Make sure we don't have nans - out_1 = after_outputs[0].numpy() - out_2 = outputs[0].numpy() - out_1 = out_1[~np.isnan(out_1)] - out_2 = out_2[~np.isnan(out_2)] - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) + def test_keras_save_load(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + tf_main_layer_classes = set( + module_member + for model_class in self.all_model_classes + for module in (import_module(model_class.__module__),) + for module_member_name in dir(module) + for module_member in (getattr(module, module_member_name),) + if isinstance(module_member, type) and TFMainLayer in module_member.__bases__ + ) + for main_layer_class in tf_main_layer_classes: + main_layer = main_layer_class(config) + symbolic_inputs = { + name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() + } + model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs)) + outputs = model(inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + filepath = os.path.join(tmpdirname, "keras_model.h5") + model.save(filepath) + model = tf.keras.models.load_model( + filepath, custom_objects={main_layer_class.__name__: main_layer_class} + ) + assert isinstance(model, tf.keras.Model) + after_outputs = model(inputs_dict) + self.assert_outputs_same(after_outputs, outputs) + + def assert_outputs_same(self, after_outputs, outputs): + # Make sure we don't have nans + out_1 = after_outputs[0].numpy() + out_2 = outputs[0].numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) def test_pt_tf_model_equivalence(self): if not is_torch_available(): From 0c716ede8c654cd2cdd3ab924f099e5d8277ee68 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Tue, 3 Mar 2020 22:31:38 +0000 Subject: [PATCH 03/12] Use class decorator instead of superclass When supplied by Keras deserialization, the config parameter to initializers will be a dict. So intercept it and convert to PretrainedConfig object (and store in instance attribute for get_config to get at it) before passing to the actual initializer. To accomplish this, and repeat as little code as possible, use a class decorator on TF*MainLayer classes. --- src/transformers/configuration_auto.py | 15 ++++++++++ src/transformers/modeling_tf_albert.py | 7 +++-- src/transformers/modeling_tf_bert.py | 7 +++-- src/transformers/modeling_tf_ctrl.py | 7 +++-- src/transformers/modeling_tf_distilbert.py | 7 +++-- src/transformers/modeling_tf_gpt2.py | 7 +++-- src/transformers/modeling_tf_openai.py | 7 +++-- src/transformers/modeling_tf_roberta.py | 3 +- src/transformers/modeling_tf_t5.py | 18 ++++++++---- src/transformers/modeling_tf_transfo_xl.py | 7 +++-- src/transformers/modeling_tf_utils.py | 32 ++++++++++++++-------- src/transformers/modeling_tf_xlm.py | 7 +++-- src/transformers/modeling_tf_xlnet.py | 7 +++-- tests/test_modeling_tf_common.py | 9 ++++-- 14 files changed, 94 insertions(+), 46 deletions(-) diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 3b112704cc..c363c89e3d 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -17,6 +17,7 @@ import logging from collections import OrderedDict +from importlib import import_module from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig @@ -100,6 +101,20 @@ class AutoConfig: "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method." ) + @classmethod + def config_class_for_model_class(cls, model_class): + module = import_module(model_class.__module__) + return next( + ( + module_attribute + for module_attribute_name in dir(module) + if module_attribute_name.endswith("Config") + for module_attribute in (getattr(module, module_attribute_name),) + if issubclass(module_attribute, PretrainedConfig) + ), + None, + ) + @classmethod def for_model(cls, model_type, *args, **kwargs): for pattern, config_class in CONFIG_MAPPING.items(): diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index b27f0eb8fa..698a1523d6 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import ACT2FN, TFBertSelfAttention -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -478,9 +478,10 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): return hidden_states -class TFAlbertMainLayer(TFMainLayer): +@keras_serializable +class TFAlbertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFAlbertEmbeddings(config, name="embeddings") diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 623319bcf1..30bfdb4347 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_bert import BertConfig from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -471,9 +471,10 @@ class TFBertNSPHead(tf.keras.layers.Layer): return seq_relationship_score -class TFBertMainLayer(TFMainLayer): +@keras_serializable +class TFBertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFBertEmbeddings(config, name="embeddings") diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 6e3d0b1b5a..d801b7cb50 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_ctrl import CTRLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -164,9 +164,10 @@ class TFEncoderLayer(tf.keras.layers.Layer): return outputs -class TFCTRLMainLayer(TFMainLayer): +@keras_serializable +class TFCTRLMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.output_past = config.output_past diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index 1570dabc07..c58243965a 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_distilbert import DistilBertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -397,9 +397,10 @@ class TFTransformer(tf.keras.layers.Layer): return outputs # last-layer hidden state, (all hidden states), (all attentions) -class TFDistilBertMainLayer(TFMainLayer): +@keras_serializable +class TFDistilBertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index d1f1cc9147..7a9a72bd68 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -25,11 +25,11 @@ from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -197,9 +197,10 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, present, (attentions) -class TFGPT2MainLayer(TFMainLayer): +@keras_serializable +class TFGPT2MainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) + super().__init__(*inputs, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.num_hidden_layers = config.n_layer diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 99b6533a0c..954091a5bd 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -25,11 +25,11 @@ from .configuration_openai import OpenAIGPTConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -198,9 +198,10 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, (attentions) -class TFOpenAIGPTMainLayer(TFMainLayer): +@keras_serializable +class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) + super().__init__(*inputs, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.num_hidden_layers = config.n_layer diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py index 31fb43f1cc..3347d7e041 100644 --- a/src/transformers/modeling_tf_roberta.py +++ b/src/transformers/modeling_tf_roberta.py @@ -20,10 +20,11 @@ import logging import tensorflow as tf +from . import PretrainedConfig from .configuration_roberta import RobertaConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 974f744dca..b0f56edcfa 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -25,7 +25,7 @@ import tensorflow as tf from .configuration_t5 import T5Config from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -359,9 +359,10 @@ class TFT5Block(tf.keras.layers.Layer): # The full model without a specific pretrained or finetuning head is # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" #################################################### -class TFT5MainLayer(TFMainLayer): +@keras_serializable +class TFT5MainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.is_decoder = config.is_decoder @@ -383,14 +384,21 @@ class TFT5MainLayer(TFMainLayer): def call( self, - hidden_states, + inputs, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, training=False, ): - + if isinstance(inputs, (tuple, list)): + hidden_states = inputs[0] + assert len(inputs) <= 1, "Too many inputs." + elif isinstance(inputs, dict): + hidden_states = inputs["hidden_states"] + assert len(inputs) <= 1, "Too many inputs." + else: + hidden_states = inputs batch_size, seq_length = shape_list(hidden_states)[:2] if attention_mask is None: attention_mask = tf.fill((batch_size, seq_length), 1) diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 1a65cce874..612e7a711d 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_transfo_xl import TransfoXLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask -from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list logger = logging.getLogger(__name__) @@ -378,9 +378,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer): return embed -class TFTransfoXLMainLayer(TFMainLayer): +@keras_serializable +class TFTransfoXLMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e2e12fb5b1..c024aa8704 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,21 +47,31 @@ class TFModelUtilsMixin: return self.count_params() -class TFMainLayer(tf.keras.layers.Layer): - """ - A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization. - """ +def keras_serializable(cls): + initializer = cls.__init__ - def __init__(self, config, **kwargs): - super().__init__(**kwargs) + def wrapped_init(self, config, *args, **kwargs): if isinstance(config, dict): - config = PretrainedConfig.from_dict(config) + from transformers import AutoConfig + + config = AutoConfig.config_class_for_model_class(cls).from_dict(config) + initializer(self, config, *args, **kwargs) self._transformers_config = config - def get_config(self): - cfg = super().get_config() - cfg["config"] = self._transformers_config.to_dict() - return cfg + cls.__init__ = wrapped_init + + if not hasattr(cls, "get_config"): + raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses") + if hasattr(cls.get_config, "_is_default"): + + def get_config(self): + cfg = super(cls, self).get_config() + cfg["config"] = self._transformers_config.to_dict() + return cfg + + cls.get_config = get_config + + return tf.keras.utils.register_keras_serializable()(cls) class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 2e8f2fda64..2b7f49d7f0 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -26,11 +26,11 @@ import tensorflow as tf from .configuration_xlm import XLMConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -203,9 +203,10 @@ class TFTransformerFFN(tf.keras.layers.Layer): return x -class TFXLMMainLayer(TFMainLayer): +@keras_serializable +class TFXLMMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 0050546fb0..c899628d88 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -25,11 +25,11 @@ import tensorflow as tf from .configuration_xlnet import XLNetConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( - TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, + keras_serializable, shape_list, ) @@ -349,9 +349,10 @@ class TFXLNetLMHead(tf.keras.layers.Layer): return hidden_states -class TFXLNetMainLayer(TFMainLayer): +@keras_serializable +class TFXLNetMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) + super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.output_past = config.output_past diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 6e9c967ba6..6b4fb24f5b 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -22,7 +22,6 @@ import unittest from importlib import import_module from transformers import is_tf_available, is_torch_available -from transformers.modeling_tf_utils import TFMainLayer from .utils import _tf_gpu_memory_limit, require_tf @@ -90,6 +89,7 @@ class TFModelTesterMixin: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname) after_outputs = model(inputs_dict) + self.assert_outputs_same(after_outputs, outputs) def test_keras_save_load(self): @@ -100,10 +100,14 @@ class TFModelTesterMixin: for model_class in self.all_model_classes for module in (import_module(model_class.__module__),) for module_member_name in dir(module) + if module_member_name.endswith("MainLayer") for module_member in (getattr(module, module_member_name),) - if isinstance(module_member, type) and TFMainLayer in module_member.__bases__ + if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__ ) for main_layer_class in tf_main_layer_classes: + if main_layer_class.__name__ == "TFT5MainLayer": + # Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly + continue main_layer = main_layer_class(config) symbolic_inputs = { name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() @@ -125,6 +129,7 @@ class TFModelTesterMixin: # Make sure we don't have nans out_1 = after_outputs[0].numpy() out_2 = outputs[0].numpy() + self.assertEqual(out_1.shape, out_2.shape) out_1 = out_1[~np.isnan(out_1)] out_2 = out_2[~np.isnan(out_2)] max_diff = np.amax(np.abs(out_1 - out_2)) From 470753bcf566b573b82dde640c11344c6342ff6a Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Tue, 3 Mar 2020 22:44:38 +0000 Subject: [PATCH 04/12] Put @keras_serializable only on layers it works on And only run the test on TF*MainLayer classes so marked. --- src/transformers/modeling_tf_distilbert.py | 1 - src/transformers/modeling_tf_openai.py | 1 - src/transformers/modeling_tf_t5.py | 1 - src/transformers/modeling_tf_utils.py | 1 + src/transformers/modeling_tf_xlm.py | 1 - tests/test_modeling_tf_common.py | 4 +--- 6 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index c58243965a..64cd93863a 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -397,7 +397,6 @@ class TFTransformer(tf.keras.layers.Layer): return outputs # last-layer hidden state, (all hidden states), (all attentions) -@keras_serializable class TFDistilBertMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 954091a5bd..50e0af3e6c 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -198,7 +198,6 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, (attentions) -@keras_serializable class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): def __init__(self, config, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index b0f56edcfa..f233074895 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -359,7 +359,6 @@ class TFT5Block(tf.keras.layers.Layer): # The full model without a specific pretrained or finetuning head is # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" #################################################### -@keras_serializable class TFT5MainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index c024aa8704..e10afb37a8 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -71,6 +71,7 @@ def keras_serializable(cls): cls.get_config = get_config + cls._keras_serializable = True return tf.keras.utils.register_keras_serializable()(cls) diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 2b7f49d7f0..8a3f200e6f 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -203,7 +203,6 @@ class TFTransformerFFN(tf.keras.layers.Layer): return x -@keras_serializable class TFXLMMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 6b4fb24f5b..8f4cae7491 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -103,11 +103,9 @@ class TFModelTesterMixin: if module_member_name.endswith("MainLayer") for module_member in (getattr(module, module_member_name),) if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__ + and getattr(module_member, '_keras_serializable', False) ) for main_layer_class in tf_main_layer_classes: - if main_layer_class.__name__ == "TFT5MainLayer": - # Not really a “main layer” as in the other models, as this one doesn't receive the test inputs directly - continue main_layer = main_layer_class(config) symbolic_inputs = { name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items() From 96c4990165f8096da3de30954811b42b731a286d Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Tue, 3 Mar 2020 22:57:05 +0000 Subject: [PATCH 05/12] fix unused imports and style --- src/transformers/modeling_tf_distilbert.py | 2 +- src/transformers/modeling_tf_openai.py | 1 - src/transformers/modeling_tf_roberta.py | 3 +-- src/transformers/modeling_tf_t5.py | 2 +- src/transformers/modeling_tf_xlm.py | 9 +-------- tests/test_modeling_tf_common.py | 5 +++-- 6 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index 64cd93863a..6f6eaa3be0 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_distilbert import DistilBertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list logger = logging.getLogger(__name__) diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 50e0af3e6c..6a97ae7786 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -29,7 +29,6 @@ from .modeling_tf_utils import ( TFSequenceSummary, TFSharedEmbeddings, get_initializer, - keras_serializable, shape_list, ) diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py index 3347d7e041..31fb43f1cc 100644 --- a/src/transformers/modeling_tf_roberta.py +++ b/src/transformers/modeling_tf_roberta.py @@ -20,11 +20,10 @@ import logging import tensorflow as tf -from . import PretrainedConfig from .configuration_roberta import RobertaConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list +from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index f233074895..41520e044b 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -25,7 +25,7 @@ import tensorflow as tf from .configuration_t5 import T5Config from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list +from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list logger = logging.getLogger(__name__) diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 8a3f200e6f..6e94a7206e 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -25,14 +25,7 @@ import tensorflow as tf from .configuration_xlm import XLMConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import ( - TFPreTrainedModel, - TFSequenceSummary, - TFSharedEmbeddings, - get_initializer, - keras_serializable, - shape_list, -) +from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list logger = logging.getLogger(__name__) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 8f4cae7491..7c93ea16cd 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -102,8 +102,9 @@ class TFModelTesterMixin: for module_member_name in dir(module) if module_member_name.endswith("MainLayer") for module_member in (getattr(module, module_member_name),) - if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__ - and getattr(module_member, '_keras_serializable', False) + if isinstance(module_member, type) + and tf.keras.layers.Layer in module_member.__bases__ + and getattr(module_member, "_keras_serializable", False) ) for main_layer_class in tf_main_layer_classes: main_layer = main_layer_class(config) From 18f4b9274f1e20182047e2cc312b4ba3ed7b61bc Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Wed, 4 Mar 2020 16:57:28 +0000 Subject: [PATCH 06/12] fix: work with Tensorflow < 2.1.0 tf.keras.utils.register_keras_serializable was added in TF 2.1.0, so don't rely on it being there; just decorate the class with it if it exists. --- src/transformers/modeling_tf_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e10afb37a8..3c94c6bf60 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -72,7 +72,9 @@ def keras_serializable(cls): cls.get_config = get_config cls._keras_serializable = True - return tf.keras.utils.register_keras_serializable()(cls) + if hasattr(tf.keras.utils, "register_keras_serializable"): + cls = tf.keras.utils.register_keras_serializable()(cls) + return cls class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): From 6fe1cc08743ee8f8b76c622e5a0742413048e857 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Wed, 4 Mar 2020 23:24:13 +0000 Subject: [PATCH 07/12] fix: clean up inadvertent change in tf_t5 This was the beginnings of an attempt to address the test failure on this layer, and instead I backed out of making this layer keras-serializable at all ... so it was a mistake to commit this. --- src/transformers/modeling_tf_t5.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index 41520e044b..db62e784b1 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -383,21 +383,14 @@ class TFT5MainLayer(tf.keras.layers.Layer): def call( self, - inputs, + hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, training=False, ): - if isinstance(inputs, (tuple, list)): - hidden_states = inputs[0] - assert len(inputs) <= 1, "Too many inputs." - elif isinstance(inputs, dict): - hidden_states = inputs["hidden_states"] - assert len(inputs) <= 1, "Too many inputs." - else: - hidden_states = inputs + batch_size, seq_length = shape_list(hidden_states)[:2] if attention_mask is None: attention_mask = tf.fill((batch_size, seq_length), 1) From 4f338ed407a0eca44cbdb342beba0c9f23bbea1e Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Wed, 4 Mar 2020 23:45:29 +0000 Subject: [PATCH 08/12] Explicit config_class instead of module inspection --- src/transformers/configuration_auto.py | 14 -------------- src/transformers/modeling_tf_albert.py | 2 ++ src/transformers/modeling_tf_bert.py | 2 ++ src/transformers/modeling_tf_ctrl.py | 2 ++ src/transformers/modeling_tf_gpt2.py | 2 ++ src/transformers/modeling_tf_transfo_xl.py | 2 ++ src/transformers/modeling_tf_utils.py | 8 +++++--- src/transformers/modeling_tf_xlnet.py | 2 ++ 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index c363c89e3d..37fd2feae3 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -101,20 +101,6 @@ class AutoConfig: "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method." ) - @classmethod - def config_class_for_model_class(cls, model_class): - module = import_module(model_class.__module__) - return next( - ( - module_attribute - for module_attribute_name in dir(module) - if module_attribute_name.endswith("Config") - for module_attribute in (getattr(module, module_attribute_name),) - if issubclass(module_attribute, PretrainedConfig) - ), - None, - ) - @classmethod def for_model(cls, model_type, *args, **kwargs): for pattern, config_class in CONFIG_MAPPING.items(): diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 698a1523d6..78b0077d8e 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -480,6 +480,8 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): @keras_serializable class TFAlbertMainLayer(tf.keras.layers.Layer): + config_class = AlbertConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 30bfdb4347..11ae8da6b8 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -473,6 +473,8 @@ class TFBertNSPHead(tf.keras.layers.Layer): @keras_serializable class TFBertMainLayer(tf.keras.layers.Layer): + config_class = BertConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.num_hidden_layers = config.num_hidden_layers diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index d801b7cb50..3c13671f4b 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -166,6 +166,8 @@ class TFEncoderLayer(tf.keras.layers.Layer): @keras_serializable class TFCTRLMainLayer(tf.keras.layers.Layer): + config_class = CTRLConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 7a9a72bd68..18752c34aa 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -199,6 +199,8 @@ class TFBlock(tf.keras.layers.Layer): @keras_serializable class TFGPT2MainLayer(tf.keras.layers.Layer): + config_class = GPT2Config + def __init__(self, config, *inputs, **kwargs): super().__init__(*inputs, **kwargs) self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 612e7a711d..53c4dc7bef 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -380,6 +380,8 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer): @keras_serializable class TFTransfoXLMainLayer(tf.keras.layers.Layer): + config_class = TransfoXLConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.output_attentions = config.output_attentions diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 3c94c6bf60..122ee9c911 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -50,11 +50,13 @@ class TFModelUtilsMixin: def keras_serializable(cls): initializer = cls.__init__ + config_class = getattr(cls, "config_class", None) + if config_class is None: + raise AttributeError("Must set `config_class` to use @keras_serializable") + def wrapped_init(self, config, *args, **kwargs): if isinstance(config, dict): - from transformers import AutoConfig - - config = AutoConfig.config_class_for_model_class(cls).from_dict(config) + config = config_class.from_dict(config) initializer(self, config, *args, **kwargs) self._transformers_config = config diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index c899628d88..8797a22194 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -351,6 +351,8 @@ class TFXLNetLMHead(tf.keras.layers.Layer): @keras_serializable class TFXLNetMainLayer(tf.keras.layers.Layer): + config_class = XLNetConfig + def __init__(self, config, **kwargs): super().__init__(**kwargs) self.output_attentions = config.output_attentions From d262a5d48eda7be8342a1447010f18d59009187d Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Thu, 5 Mar 2020 11:05:29 +0000 Subject: [PATCH 09/12] fix: remove unused import --- src/transformers/configuration_auto.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 37fd2feae3..3b112704cc 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -17,7 +17,6 @@ import logging from collections import OrderedDict -from importlib import import_module from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig From a355f4f0fcd56df48a39c58ae5541c4c2600fd92 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Thu, 5 Mar 2020 11:11:42 +0000 Subject: [PATCH 10/12] Add functools.wraps for wrapper initializer Preserve the original initializer function's metadata. See https://docs.python.org/3/library/functools.html#functools.update_wrapper --- src/transformers/modeling_tf_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 122ee9c911..29049f7371 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -14,8 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """TF general model utils.""" - - +import functools import logging import os @@ -54,6 +53,7 @@ def keras_serializable(cls): if config_class is None: raise AttributeError("Must set `config_class` to use @keras_serializable") + @functools.wraps(initializer) def wrapped_init(self, config, *args, **kwargs): if isinstance(config, dict): config = config_class.from_dict(config) From 4be01e5cbff5c449f6ef4c0cbb79ccd625dec156 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Thu, 5 Mar 2020 11:08:45 +0000 Subject: [PATCH 11/12] Use name transformers_config in Keras serialization Be explicit that this is config for the transformers package (as these layers may coexist with other custom stuff in a Keras model, plus the Keras container itself is called config, and config["config"] is not great) Add explicit error handling for initializer calls that have neither the `config` nor the `transformers_config` argument, or have both. --- src/transformers/modeling_tf_utils.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 29049f7371..accaed00ae 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -54,10 +54,20 @@ def keras_serializable(cls): raise AttributeError("Must set `config_class` to use @keras_serializable") @functools.wraps(initializer) - def wrapped_init(self, config, *args, **kwargs): - if isinstance(config, dict): - config = config_class.from_dict(config) - initializer(self, config, *args, **kwargs) + def wrapped_init(self, *args, **kwargs): + transformers_config = kwargs.pop("transformers_config", None) + config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None) + if config is not None and transformers_config is not None: + raise ValueError("Must pass either `config` or `transformers_config`, not both") + elif config is not None: + # normal layer construction, call with unchanged args (config is already in there) + initializer(self, *args, **kwargs) + elif transformers_config is not None: + # Keras deserialization, convert dict to config + config = config_class.from_dict(transformers_config) + initializer(self, config, *args, **kwargs) + else: + raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)") self._transformers_config = config cls.__init__ = wrapped_init @@ -68,7 +78,7 @@ def keras_serializable(cls): def get_config(self): cfg = super(cls, self).get_config() - cfg["config"] = self._transformers_config.to_dict() + cfg["transformers_config"] = self._transformers_config.to_dict() return cfg cls.get_config = get_config From 4c91a3af941627297c4b81fb48be1c9cb5ae7ee0 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Thu, 5 Mar 2020 11:48:10 +0000 Subject: [PATCH 12/12] Document keras_serializable decorator --- src/transformers/modeling_tf_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index accaed00ae..4fc00ff01a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,6 +47,21 @@ class TFModelUtilsMixin: def keras_serializable(cls): + """ + Decorate a Keras Layer class to support Keras serialization. + + This is done by: + 1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at + serialization time + 2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and + convert it to a config object for the actual layer initializer + 3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does + not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model` + + :param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a + `TF*MainLayer` class in this project) + :return: the same class object, with modifications for Keras deserialization. + """ initializer = cls.__init__ config_class = getattr(cls, "config_class", None)