Add "tie_word_embeddings" config param (#6692)
* add tie_word_embeddings * correct word embeddings in modeling utils * make style * make config param only relevant for torch * make style * correct typo * delete deprecated arg in transfo-xl
This commit is contained in:
committed by
GitHub
parent
fa8ee8e855
commit
925f34bbbd
@@ -165,9 +165,16 @@ class ReformerConfig(PretrainedConfig):
|
|||||||
num_hashes=1,
|
num_hashes=1,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
vocab_size=320,
|
vocab_size=320,
|
||||||
|
tie_word_embeddings=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_decoder=is_decoder, **kwargs)
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
is_decoder=is_decoder,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
self.hash_seed = hash_seed
|
self.hash_seed = hash_seed
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
@@ -79,8 +80,6 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
number of samples in sampled softmax
|
number of samples in sampled softmax
|
||||||
adaptive (:obj:`boolean`, optional, defaults to :obj:`True`):
|
adaptive (:obj:`boolean`, optional, defaults to :obj:`True`):
|
||||||
use adaptive softmax
|
use adaptive softmax
|
||||||
tie_weight (:obj:`boolean`, optional, defaults to :obj:`True`):
|
|
||||||
tie the word embedding and softmax weights
|
|
||||||
dropout (:obj:`float`, optional, defaults to 0.1):
|
dropout (:obj:`float`, optional, defaults to 0.1):
|
||||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||||
dropatt (:obj:`float`, optional, defaults to 0):
|
dropatt (:obj:`float`, optional, defaults to 0):
|
||||||
@@ -135,7 +134,6 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
attn_type=0,
|
attn_type=0,
|
||||||
sample_softmax=-1,
|
sample_softmax=-1,
|
||||||
adaptive=True,
|
adaptive=True,
|
||||||
tie_weight=True,
|
|
||||||
dropout=0.1,
|
dropout=0.1,
|
||||||
dropatt=0.0,
|
dropatt=0.0,
|
||||||
untie_r=True,
|
untie_r=True,
|
||||||
@@ -147,12 +145,17 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
eos_token_id=0,
|
eos_token_id=0,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(eos_token_id=eos_token_id, **kwargs)
|
if "tie_weight" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The config parameter `tie_weight` is deprecated. Please use `tie_word_embeddings` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
kwargs["tie_word_embeddings"] = kwargs["tie_weight"]
|
||||||
|
|
||||||
|
super().__init__(eos_token_id=eos_token_id, **kwargs)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.cutoffs = []
|
self.cutoffs = []
|
||||||
self.cutoffs.extend(cutoffs)
|
self.cutoffs.extend(cutoffs)
|
||||||
self.tie_weight = tie_weight
|
|
||||||
if proj_share_all_but_first:
|
if proj_share_all_but_first:
|
||||||
self.tie_projs = [False] + [True] * len(self.cutoffs)
|
self.tie_projs = [False] + [True] * len(self.cutoffs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -134,6 +134,7 @@ class PretrainedConfig(object):
|
|||||||
PyTorch specific parameters
|
PyTorch specific parameters
|
||||||
- **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be
|
- **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should be
|
||||||
used with Torchscript.
|
used with Torchscript.
|
||||||
|
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has an output word embedding layer.
|
||||||
|
|
||||||
TensorFlow specific parameters
|
TensorFlow specific parameters
|
||||||
- **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should
|
- **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should
|
||||||
@@ -150,6 +151,9 @@ class PretrainedConfig(object):
|
|||||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||||
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
||||||
|
self.tie_word_embeddings = kwargs.pop(
|
||||||
|
"tie_word_embeddings", True
|
||||||
|
) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
|
||||||
|
|
||||||
# is_decoder is used in encoder-decoder models to differentiate encoder from decoder
|
# is_decoder is used in encoder-decoder models to differentiate encoder from decoder
|
||||||
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
||||||
|
|||||||
@@ -647,14 +647,13 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
|
|||||||
self.sop_classifier = AlbertSOPHead(config)
|
self.sop_classifier = AlbertSOPHead(config)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
self.tie_weights()
|
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.predictions.decoder
|
return self.predictions.decoder
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.albert.embeddings.word_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -798,14 +797,13 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
|||||||
self.predictions = AlbertMLMHead(config)
|
self.predictions = AlbertMLMHead(config)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
self.tie_weights()
|
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.predictions.decoder
|
return self.predictions.decoder
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.albert.embeddings.word_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -945,7 +945,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
|||||||
self.cls.predictions.dense = resized_dense
|
self.cls.predictions.dense = resized_dense
|
||||||
self.cls.predictions.dense.to(self.device)
|
self.cls.predictions.dense.to(self.device)
|
||||||
|
|
||||||
if output_embeddings is not None:
|
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||||
@@ -1060,7 +1060,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
|||||||
self.cls.predictions.dense = resized_dense
|
self.cls.predictions.dense = resized_dense
|
||||||
self.cls.predictions.dense.to(self.device)
|
self.cls.predictions.dense.to(self.device)
|
||||||
|
|
||||||
if output_embeddings is not None:
|
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
|
|||||||
@@ -2155,10 +2155,6 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
# word embeddings are not tied in Reformer
|
|
||||||
pass
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -2274,10 +2270,6 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
# word embeddings are not tied in Reformer
|
|
||||||
pass
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -2356,10 +2348,6 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
# word embeddings are not tied in Reformer
|
|
||||||
pass
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -2459,10 +2447,6 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
# word embeddings are not tied in Reformer
|
|
||||||
pass
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def build_tf_to_pytorch_map(model, config):
|
|||||||
zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
|
zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
|
||||||
):
|
):
|
||||||
layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
|
layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
|
||||||
if config.tie_weight:
|
if config.tie_word_embeddings:
|
||||||
tf_to_pt_map.update({layer_str + "b": out_l.bias})
|
tf_to_pt_map.update({layer_str + "b": out_l.bias})
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -978,7 +978,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
Run this to be sure output and input (adaptive) softmax weights are tied
|
Run this to be sure output and input (adaptive) softmax weights are tied
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_word_embeddings:
|
||||||
for i in range(len(self.crit.out_layers)):
|
for i in range(len(self.crit.out_layers)):
|
||||||
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
|
self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
|
||||||
if self.config.tie_projs:
|
if self.config.tie_projs:
|
||||||
|
|||||||
@@ -413,7 +413,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
the weights instead.
|
the weights instead.
|
||||||
"""
|
"""
|
||||||
output_embeddings = self.get_output_embeddings()
|
output_embeddings = self.get_output_embeddings()
|
||||||
if output_embeddings is not None:
|
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||||
|
|
||||||
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
|
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
|
||||||
|
|||||||
Reference in New Issue
Block a user