From ba28170717fbce8bddae70d065846bded799f9f3 Mon Sep 17 00:00:00 2001 From: Gunnlaugur Thor Briem Date: Tue, 3 Mar 2020 14:00:30 +0000 Subject: [PATCH] Support keras JSON/HDF5 serialization of main layers Fixes #3101 --- src/transformers/modeling_tf_albert.py | 4 ++-- src/transformers/modeling_tf_bert.py | 6 +++--- src/transformers/modeling_tf_ctrl.py | 6 +++--- src/transformers/modeling_tf_distilbert.py | 6 +++--- src/transformers/modeling_tf_gpt2.py | 5 +++-- src/transformers/modeling_tf_openai.py | 3 ++- src/transformers/modeling_tf_t5.py | 6 +++--- src/transformers/modeling_tf_transfo_xl.py | 6 +++--- src/transformers/modeling_tf_utils.py | 17 +++++++++++++++++ src/transformers/modeling_tf_xlm.py | 13 ++++++++++--- src/transformers/modeling_tf_xlnet.py | 13 ++++++++++--- 11 files changed, 59 insertions(+), 26 deletions(-) diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 64c9dad06a..b27f0eb8fa 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import ACT2FN, TFBertSelfAttention -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -478,7 +478,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): return hidden_states -class TFAlbertMainLayer(tf.keras.layers.Layer): +class TFAlbertMainLayer(TFMainLayer): def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self.num_hidden_layers = config.num_hidden_layers diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 1904623581..623319bcf1 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_bert import BertConfig from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -471,9 +471,9 @@ class TFBertNSPHead(tf.keras.layers.Layer): return seq_relationship_score -class TFBertMainLayer(tf.keras.layers.Layer): +class TFBertMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFBertEmbeddings(config, name="embeddings") diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 335421979c..6e3d0b1b5a 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -23,7 +23,7 @@ import tensorflow as tf from .configuration_ctrl import CTRLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list logger = logging.getLogger(__name__) @@ -164,9 +164,9 @@ class TFEncoderLayer(tf.keras.layers.Layer): return outputs -class TFCTRLMainLayer(tf.keras.layers.Layer): +class TFCTRLMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.output_past = config.output_past diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index 6f6eaa3be0..1570dabc07 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_distilbert import DistilBertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -397,9 +397,9 @@ class TFTransformer(tf.keras.layers.Layer): return outputs # last-layer hidden state, (all hidden states), (all attentions) -class TFDistilBertMainLayer(tf.keras.layers.Layer): +class TFDistilBertMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.num_hidden_layers = config.num_hidden_layers self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 7e9b102b6d..d1f1cc9147 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -25,6 +25,7 @@ from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, + TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, @@ -196,9 +197,9 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, present, (attentions) -class TFGPT2MainLayer(tf.keras.layers.Layer): +class TFGPT2MainLayer(TFMainLayer): def __init__(self, config, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) + super().__init__(config, *inputs, **kwargs) self.output_hidden_states = config.output_hidden_states self.output_attentions = config.output_attentions self.num_hidden_layers = config.n_layer diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index f04104db83..99b6533a0c 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -25,6 +25,7 @@ from .configuration_openai import OpenAIGPTConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import ( TFConv1D, + TFMainLayer, TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, @@ -197,7 +198,7 @@ class TFBlock(tf.keras.layers.Layer): return outputs # x, (attentions) -class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): +class TFOpenAIGPTMainLayer(TFMainLayer): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index db62e784b1..974f744dca 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -25,7 +25,7 @@ import tensorflow as tf from .configuration_t5 import T5Config from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings -from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list logger = logging.getLogger(__name__) @@ -359,9 +359,9 @@ class TFT5Block(tf.keras.layers.Layer): # The full model without a specific pretrained or finetuning head is # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" #################################################### -class TFT5MainLayer(tf.keras.layers.Layer): +class TFT5MainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.is_decoder = config.is_decoder diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 098a4c9143..1a65cce874 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -24,7 +24,7 @@ import tensorflow as tf from .configuration_transfo_xl import TransfoXLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask -from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list +from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, get_initializer, shape_list logger = logging.getLogger(__name__) @@ -378,9 +378,9 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer): return embed -class TFTransfoXLMainLayer(tf.keras.layers.Layer): +class TFTransfoXLMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 43abdd9499..e2e12fb5b1 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,6 +47,23 @@ class TFModelUtilsMixin: return self.count_params() +class TFMainLayer(tf.keras.layers.Layer): + """ + A common superclass for main layers of models, to support `get_config` and thus Keras JSON serialization. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + if isinstance(config, dict): + config = PretrainedConfig.from_dict(config) + self._transformers_config = config + + def get_config(self): + cfg = super().get_config() + cfg["config"] = self._transformers_config.to_dict() + return cfg + + class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): r""" Base class for all TF models. diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 6e94a7206e..2e8f2fda64 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -25,7 +25,14 @@ import tensorflow as tf from .configuration_xlm import XLMConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import ( + TFMainLayer, + TFPreTrainedModel, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + shape_list, +) logger = logging.getLogger(__name__) @@ -196,9 +203,9 @@ class TFTransformerFFN(tf.keras.layers.Layer): return x -class TFXLMMainLayer(tf.keras.layers.Layer): +class TFXLMMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 87ebe16858..0050546fb0 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -24,7 +24,14 @@ import tensorflow as tf from .configuration_xlnet import XLNetConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable -from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list +from .modeling_tf_utils import ( + TFMainLayer, + TFPreTrainedModel, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + shape_list, +) logger = logging.getLogger(__name__) @@ -342,9 +349,9 @@ class TFXLNetLMHead(tf.keras.layers.Layer): return hidden_states -class TFXLNetMainLayer(tf.keras.layers.Layer): +class TFXLNetMainLayer(TFMainLayer): def __init__(self, config, **kwargs): - super().__init__(**kwargs) + super().__init__(config, **kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.output_past = config.output_past