Add "tie_word_embeddings" config param (#6692)
* add tie_word_embeddings * correct word embeddings in modeling utils * make style * make config param only relevant for torch * make style * correct typo * delete deprecated arg in transfo-xl
This commit is contained in:
committed by
GitHub
parent
fa8ee8e855
commit
925f34bbbd
@@ -165,9 +165,16 @@ class ReformerConfig(PretrainedConfig):
|
||||
num_hashes=1,
|
||||
pad_token_id=0,
|
||||
vocab_size=320,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs)
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_decoder=is_decoder,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.hash_seed = hash_seed
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
@@ -79,8 +80,6 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
number of samples in sampled softmax
|
||||
adaptive (:obj:`boolean`, optional, defaults to :obj:`True`):
|
||||
use adaptive softmax
|
||||
tie_weight (:obj:`boolean`, optional, defaults to :obj:`True`):
|
||||
tie the word embedding and softmax weights
|
||||
dropout (:obj:`float`, optional, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
dropatt (:obj:`float`, optional, defaults to 0):
|
||||
@@ -135,7 +134,6 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
attn_type=0,
|
||||
sample_softmax=-1,
|
||||
adaptive=True,
|
||||
tie_weight=True,
|
||||
dropout=0.1,
|
||||
dropatt=0.0,
|
||||
untie_r=True,
|
||||
@@ -147,12 +145,17 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
eos_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(eos_token_id=eos_token_id, **kwargs)
|
||||
if "tie_weight" in kwargs:
|
||||
warnings.warn(
|
||||
"The config parameter `tie_weight` is deprecated. Please use `tie_word_embeddings` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
kwargs["tie_word_embeddings"] = kwargs["tie_weight"]
|
||||
|
||||
super().__init__(eos_token_id=eos_token_id, **kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.cutoffs = []
|
||||
self.cutoffs.extend(cutoffs)
|
||||
self.tie_weight = tie_weight
|
||||
if proj_share_all_but_first:
|
||||
self.tie_projs = [False] + [True] * len(self.cutoffs)
|
||||
else:
|
||||
|
||||
@@ -134,6 +134,7 @@ class PretrainedConfig(object):
|
||||
PyTorch specific parameters
|
||||
- **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be
|
||||
used with Torchscript.
|
||||
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has an output word embedding layer.
|
||||
|
||||
TensorFlow specific parameters
|
||||
- **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should
|
||||
@@ -150,6 +151,9 @@ class PretrainedConfig(object):
|
||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||
self.tie_word_embeddings = kwargs.pop(
|
||||
"tie_word_embeddings", True
|
||||
) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
|
||||
|
||||
# `is_decoder` is used in encoder-decoder models to differentiate encoder from decoder
|
||||
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
||||
|
||||
@@ -647,14 +647,13 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
|
||||
self.sop_classifier = AlbertSOPHead(config)
|
||||
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.predictions.decoder
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.albert.embeddings.word_embeddings
|
||||
|
||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -798,14 +797,13 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||
self.predictions = AlbertMLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.predictions.decoder
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.albert.embeddings.word_embeddings
|
||||
|
||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -945,7 +945,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
if output_embeddings is not None:
|
||||
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||
@@ -1060,7 +1060,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
if output_embeddings is not None:
|
||||
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
|
||||
@@ -2155,10 +2155,6 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def tie_weights(self):
|
||||
# word embeddings are not tied in Reformer
|
||||
pass
|
||||
|
||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
@@ -2274,10 +2270,6 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def tie_weights(self):
|
||||
# word embeddings are not tied in Reformer
|
||||
pass
|
||||
|
||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
@@ -2356,10 +2348,6 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
# word embeddings are not tied in Reformer
|
||||
pass
|
||||
|
||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
@@ -2459,10 +2447,6 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
# word embeddings are not tied in Reformer
|
||||
pass
|
||||
|
||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -62,7 +62,7 @@ def build_tf_to_pytorch_map(model, config):
|
||||
zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
|
||||
):
|
||||
layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
|
||||
if config.tie_weight:
|
||||
if config.tie_word_embeddings:
|
||||
tf_to_pt_map.update({layer_str + "b": out_l.bias})
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@@ -978,7 +978,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
Run this to be sure output and input (adaptive) softmax weights are tied
|
||||
"""
|
||||
|
||||
if self.config.tie_weight:
|
||||
if self.config.tie_word_embeddings:
|
||||
for i in range(len(self.crit.out_layers)):
|
||||
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
|
||||
if self.config.tie_projs:
|
||||
|
||||
@@ -413,7 +413,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
the weights instead.
|
||||
"""
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
if output_embeddings is not None:
|
||||
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
|
||||
|
||||
Reference in New Issue
Block a user