From 925f34bbbd351025ab4d2d2752623e08c55ccf5c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 26 Aug 2020 10:58:21 +0200 Subject: [PATCH] Add "tie_word_embeddings" config param (#6692) * add tie_word_embeddings * correct word embeddings in modeling utils * make style * make config param only relevant for torch * make style * correct typo * delete deprecated arg in transo-xl --- src/transformers/configuration_reformer.py | 9 ++++++++- src/transformers/configuration_transfo_xl.py | 13 ++++++++----- src/transformers/configuration_utils.py | 4 ++++ src/transformers/modeling_albert.py | 14 ++++++-------- src/transformers/modeling_mobilebert.py | 4 ++-- src/transformers/modeling_reformer.py | 16 ---------------- src/transformers/modeling_transfo_xl.py | 4 ++-- src/transformers/modeling_utils.py | 2 +- 8 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/transformers/configuration_reformer.py b/src/transformers/configuration_reformer.py index dec1d726e3..eace75a00f 100755 --- a/src/transformers/configuration_reformer.py +++ b/src/transformers/configuration_reformer.py @@ -165,9 +165,16 @@ class ReformerConfig(PretrainedConfig): num_hashes=1, pad_token_id=0, vocab_size=320, + tie_word_embeddings=False, **kwargs ): - super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs) + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_decoder=is_decoder, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.hash_seed = hash_seed self.vocab_size = vocab_size diff --git a/src/transformers/configuration_transfo_xl.py b/src/transformers/configuration_transfo_xl.py index c3c6a22b82..b9c91151b1 100644 --- a/src/transformers/configuration_transfo_xl.py +++ b/src/transformers/configuration_transfo_xl.py @@ -17,6 +17,7 @@ import logging +import warnings from .configuration_utils import PretrainedConfig @@ -79,8 +80,6 @@ class TransfoXLConfig(PretrainedConfig): number of samples in sampled softmax adaptive (:obj:`boolean`, optional, defaults to :obj:`True`): use adaptive softmax - tie_weight (:obj:`boolean`, optional, defaults to :obj:`True`): - tie the word embedding and softmax weights dropout (:obj:`float`, optional, defaults to 0.1): The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. dropatt (:obj:`float`, optional, defaults to 0): @@ -135,7 +134,6 @@ class TransfoXLConfig(PretrainedConfig): attn_type=0, sample_softmax=-1, adaptive=True, - tie_weight=True, dropout=0.1, dropatt=0.0, untie_r=True, @@ -147,12 +145,17 @@ class TransfoXLConfig(PretrainedConfig): eos_token_id=0, **kwargs ): - super().__init__(eos_token_id=eos_token_id, **kwargs) + if "tie_weight" in kwargs: + warnings.warn( + "The config parameter `tie_weight` is deprecated. Please use `tie_word_embeddings` instead.", + FutureWarning, + ) + kwargs["tie_word_embeddings"] = kwargs["tie_weight"] + super().__init__(eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.cutoffs = [] self.cutoffs.extend(cutoffs) - self.tie_weight = tie_weight if proj_share_all_but_first: self.tie_projs = [False] + [True] * len(self.cutoffs) else: diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index fa9466326a..c82fe8d211 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -134,6 +134,7 @@ class PretrainedConfig(object): PyTorch specific parameters - **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be used with Torchscript. + - **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. TensorFlow specific parameters - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should @@ -150,6 +151,9 @@ class PretrainedConfig(object): self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models self.use_bfloat16 = kwargs.pop("use_bfloat16", False) self.pruned_heads = kwargs.pop("pruned_heads", {}) + self.tie_word_embeddings = kwargs.pop( + "tie_word_embeddings", True + ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models. # Is decoder is used in encoder-decoder models to differentiate encoder from decoder self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py index 76f0998952..fb43e7cb29 100755 --- a/src/transformers/modeling_albert.py +++ b/src/transformers/modeling_albert.py @@ -647,14 +647,13 @@ class AlbertForPreTraining(AlbertPreTrainedModel): self.sop_classifier = AlbertSOPHead(config) self.init_weights() - self.tie_weights() - - def tie_weights(self): - self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings) def get_output_embeddings(self): return self.predictions.decoder + def get_input_embeddings(self): + return self.albert.embeddings.word_embeddings + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -798,14 +797,13 @@ class AlbertForMaskedLM(AlbertPreTrainedModel): self.predictions = AlbertMLMHead(config) self.init_weights() - self.tie_weights() - - def tie_weights(self): - self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings) def get_output_embeddings(self): return self.predictions.decoder + def get_input_embeddings(self): + return self.albert.embeddings.word_embeddings + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/modeling_mobilebert.py b/src/transformers/modeling_mobilebert.py index 7ab605dcfc..ea75ab16a9 100644 --- a/src/transformers/modeling_mobilebert.py +++ b/src/transformers/modeling_mobilebert.py @@ -945,7 +945,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel): self.cls.predictions.dense = resized_dense self.cls.predictions.dense.to(self.device) - if output_embeddings is not None: + if output_embeddings is not None and self.config.tie_word_embeddings: self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING) @@ -1060,7 +1060,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel): self.cls.predictions.dense = resized_dense self.cls.predictions.dense.to(self.device) - if output_embeddings is not None: + if output_embeddings is not None and self.config.tie_word_embeddings: self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 2e8c5f31d3..1e20ffb39f 100755 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -2155,10 +2155,6 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder - def tie_weights(self): - # word embeddings are not tied in Reformer - pass - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -2274,10 +2270,6 @@ class ReformerForMaskedLM(ReformerPreTrainedModel): def get_output_embeddings(self): return self.lm_head.decoder - def tie_weights(self): - # word embeddings are not tied in Reformer - pass - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -2356,10 +2348,6 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel): self.init_weights() - def tie_weights(self): - # word embeddings are not tied in Reformer - pass - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -2459,10 +2447,6 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel): self.init_weights() - def tie_weights(self): - # word embeddings are not tied in Reformer - pass - @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index f17855ce1d..87bf6d9b21 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -62,7 +62,7 @@ def build_tf_to_pytorch_map(model, config): zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs) ): layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i - if config.tie_weight: + if config.tie_word_embeddings: tf_to_pt_map.update({layer_str + "b": out_l.bias}) else: raise NotImplementedError @@ -978,7 +978,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): Run this to be sure output and input (adaptive) softmax weights are tied """ - if self.config.tie_weight: + if self.config.tie_word_embeddings: for i in range(len(self.crit.out_layers)): self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i]) if self.config.tie_projs: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a6ff59ee7f..8c6c5cfdd8 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -413,7 +413,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): the weights instead. """ output_embeddings = self.get_output_embeddings() - if output_embeddings is not None: + if output_embeddings is not None and self.config.tie_word_embeddings: self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: