Fix embeddings resizing in TF models (#8657)

* Resize the biases in same time than the embeddings

* Trigger CI

* Biases are not reset anymore

* Remove get_output_embeddings + better LM model detection in generation utils

* Apply style

* First test on BERT

* Update docstring + new name

* Apply the new resizing logic to all the models

* fix tests

* Apply style

* Update the template

* Fix naming

* Fix naming

* Apply style

* Apply style

* Remove unused import

* Revert get_output_embeddings

* Trigger CI

* Update num parameters

* Restore get_output_embeddings in TFPretrainedModel and add comments

* Style

* Add decoder resizing

* Style

* Fix tests

* Separate bias and decoder resize

* Fix tests

* Fix tests

* Apply style

* Add bias resizing in MPNet

* Trigger CI

* Apply style
This commit is contained in:
Julien Plu
2020-12-14 05:05:24 +01:00
committed by GitHub
parent 3552d0e0d8
commit 51d9c569fa
31 changed files with 470 additions and 18 deletions

View File

@@ -893,6 +893,12 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1002,6 +1008,12 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -1095,6 +1107,12 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
def get_output_embeddings(self):
return self.bert.embeddings
def get_output_layer_with_bias(self):
return self.mlm.predictions
def get_prefix_bias_name(self):
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-cased",