Add "tie_word_embeddings" config param (#6692)

* add tie_word_embeddings

* correct word embeddings in modeling utils

* make style

* make config param only relevant for torch

* make style

* correct typo

* delete deprecated arg in transo-xl
This commit is contained in:
Patrick von Platen
2020-08-26 10:58:21 +02:00
committed by GitHub
parent fa8ee8e855
commit 925f34bbbd
8 changed files with 31 additions and 35 deletions

View File

@@ -945,7 +945,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None:
if output_embeddings is not None and self.config.tie_word_embeddings:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
@@ -1060,7 +1060,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None:
if output_embeddings is not None and self.config.tie_word_embeddings:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))