Add "tie_word_embeddings" config param (#6692)
* add tie_word_embeddings
* correct word embeddings in modeling utils
* make style
* make config param only relevant for torch
* make style
* correct typo
* delete deprecated arg in transfo-xl
This commit is contained in:
committed by
GitHub
parent
fa8ee8e855
commit
925f34bbbd
@@ -945,7 +945,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
- if output_embeddings is not None:
|
||||
+ if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||
@@ -1060,7 +1060,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
- if output_embeddings is not None:
|
||||
+ if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
|
||||
Reference in New Issue
Block a user