Fix embeddings resizing in TF models (#8657)
* Resize the biases in same time than the embeddings * Trigger CI * Biases are not reset anymore * Remove get_output_embeddings + better LM model detection in generation utils * Apply style * First test on BERT * Update docstring + new name * Apply the new resizing logic to all the models * fix tests * Apply style * Update the template * Fix naming * Fix naming * Apply style * Apply style * Remove unused import * Revert get_output_embeddings * Trigger CI * Update num parameters * Restore get_output_embeddings in TFPretrainedModel and add comments * Style * Add decoder resizing * Style * Fix tests * Separate bias and decoder resize * Fix tests * Fix tests * Apply style * Add bias resizing in MPNet * Trigger CI * Apply style
This commit is contained in:
@@ -893,6 +893,12 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
@@ -1002,6 +1008,12 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
@@ -1095,6 +1107,12 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-cased",
|
||||
|
||||
Reference in New Issue
Block a user