Full rework of the TF input/output embeddings and bias resizing (#9193)
* Start rework resizing * Rework bias/decoder resizing * Full resizing rework * Full resizing rework * Start updating the models with the new approach * Finish updating the models * Update all the tests * Update the template * Fix tests * Fix tests * Test a new approach * Refactoring * Refactoring * Refactoring * New rework * Rework BART * Rework bert+blenderbot * Rework CTRL * Rework Distilbert * Rework DPR * Rework Electra * Rework Flaubert * Rework Funnel * Rework GPT2 * Rework Longformer * Rework Lxmert * Rework marian+mbart * Rework mobilebert * Rework mpnet * Rework openai * Rework pegasus * Rework Roberta * Rework T5 * Rework xlm+xlnet * Rework template * Fix TFT5EncoderOnly + DPRs * Restore previous methods * Fix Funnel * Fix CTRL and TransfoXL * Apply style * Apply Sylvain's comments * Restore a test in DPR * Address the comments * Fix bug * Apply style * remove unused import * Fix test * Forgot a method * missing test * Trigger CI * naming update * Rebase * Trigger CI
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 BERT model. """
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -526,6 +527,20 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.input_embeddings
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.input_embeddings.word_embeddings = value
|
||||
self.input_embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
def get_bias(self):
|
||||
return {"bias": self.bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.bias = value["bias"]
|
||||
self.vocab_size = shape_list(value["bias"])[0]
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
@@ -582,7 +597,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
self.embeddings.vocab_size = value.shape[0]
|
||||
self.embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
@@ -918,13 +933,11 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
self.nsp = TFBertNSPHead(config, name="nsp___cls")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@@ -1044,13 +1057,11 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@@ -1149,13 +1160,11 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.bert.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_code_sample_docstrings(
|
||||
|
||||
Reference in New Issue
Block a user