Full rework of the TF input/output embeddings and bias resizing (#9193)

* Start rework resizing

* Rework bias/decoder resizing

* Full resizing rework

* Full resizing rework

* Start to update the models with the new approach

* Finish to update the models

* Update all the tests

* Update the template

* Fix tests

* Fix tests

* Test a new approach

* Refactoring

* Refactoring

* Refactoring

* New rework

* Rework BART

* Rework bert+blenderbot

* Rework CTRL

* Rework Distilbert

* Rework DPR

* Rework Electra

* Rework Flaubert

* Rework Funnel

* Rework GPT2

* Rework Longformer

* Rework Lxmert

* Rework marian+mbart

* Rework mobilebert

* Rework mpnet

* Rework openai

* Rework pegasus

* Rework Roberta

* Rework T5

* Rework xlm+xlnet

* Rework template

* Fix TFT5EncoderOnly + DPRs

* Restore previous methods

* Fix Funnel

* Fix CTRL and TransforXL

* Apply style

* Apply Sylvain's comments

* Restore a test in DPR

* Address the comments

* Fix bug

* Apply style

* remove unused import

* Fix test

* Forgot a method

* missing test

* Trigger CI

* naming update

* Rebase

* Trigger CI
This commit is contained in:
Julien Plu
2021-01-11 12:27:28 +01:00
committed by GitHub
parent cf416764f4
commit 1243ee7d0c
40 changed files with 1473 additions and 593 deletions

View File

@@ -17,6 +17,7 @@
import math
import warnings
import tensorflow as tf
@@ -541,7 +542,7 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
def _prune_heads(self, heads_to_prune):
@@ -840,6 +841,20 @@ class TFMPNetLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def get_output_embeddings(self):
return self.decoder
def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]
def get_bias(self):
return {"bias": self.bias}
def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]
def call(self, features):
x = self.dense(features)
x = self.act(x)
@@ -862,13 +877,11 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
self.mpnet = TFMPNetMainLayer(config, name="mpnet")
self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head")
def get_output_embeddings(self):
return self.mpnet.embeddings
def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head
def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name
@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))