Add "tie_word_embeddings" config param (#6692)

* add tie_word_embeddings

* correct word embeddings in modeling utils

* make style

* make config param only relevant for torch

* make style

* correct typo

* delete deprecated arg in transo-xl
This commit is contained in:
Patrick von Platen
2020-08-26 10:58:21 +02:00
committed by GitHub
parent fa8ee8e855
commit 925f34bbbd
8 changed files with 31 additions and 35 deletions

View File

@@ -413,7 +413,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
the weights instead.
"""
output_embeddings = self.get_output_embeddings()
if output_embeddings is not None:
if output_embeddings is not None and self.config.tie_word_embeddings:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: