Full rework of the TF input/output embeddings and bias resizing (#9193)
* Start reworking resizing * Rework bias/decoder resizing * Full resizing rework * Full resizing rework * Start to update the models with the new approach * Finish updating the models * Update all the tests * Update the template * Fix tests * Fix tests * Test a new approach * Refactoring * Refactoring * Refactoring * New rework * Rework BART * Rework bert+blenderbot * Rework CTRL * Rework Distilbert * Rework DPR * Rework Electra * Rework Flaubert * Rework Funnel * Rework GPT2 * Rework Longformer * Rework Lxmert * Rework marian+mbart * Rework mobilebert * Rework mpnet * Rework openai * Rework pegasus * Rework Roberta * Rework T5 * Rework xlm+xlnet * Rework template * Fix TFT5EncoderOnly + DPRs * Restore previous methods * Fix Funnel * Fix CTRL and TransfoXL * Apply style * Apply Sylvain's comments * Restore a test in DPR * Address the comments * Fix bug * Apply style * remove unused import * Fix test * Forgot a method * missing test * Trigger CI * naming update * Rebase * Trigger CI
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 CTRL model."""
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -242,10 +244,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.w.weight = value
|
||||
self.w.vocab_size = value.shape[0]
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
raise NotImplementedError
|
||||
self.w.vocab_size = shape_list(value)[0]
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
@@ -618,6 +617,20 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
|
||||
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
super().build(input_shape)
|
||||
|
||||
def get_output_embeddings(self):
    """Return the layer held in ``self.input_embeddings``.

    ``call`` applies this same layer with ``mode="linear"`` to produce its
    output, so the head exposes it as its output embeddings.
    """
    shared_layer = self.input_embeddings
    return shared_layer
|
||||
|
||||
def set_output_embeddings(self, value):
    """Install ``value`` as the weight of ``self.input_embeddings``.

    The layer's ``vocab_size`` is refreshed from the first dimension of
    ``value`` (as reported by ``shape_list``) so both stay in sync.
    """
    embedding_layer = self.input_embeddings
    embedding_layer.weight = value
    embedding_layer.vocab_size = shape_list(value)[0]
|
||||
|
||||
def get_bias(self):
    """Return the head's bias wrapped in a dict under the key ``"bias"``."""
    return dict(bias=self.bias)
|
||||
|
||||
def set_bias(self, value):
    """Replace the head's bias with ``value["bias"]``.

    ``self.vocab_size`` is updated to the first dimension of the new bias
    (as reported by ``shape_list``).
    """
    new_bias = value["bias"]
    self.bias = new_bias
    self.vocab_size = shape_list(new_bias)[0]
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
@@ -638,13 +651,11 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
|
||||
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.input_embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
def get_lm_head(self):
    """Return the language-modeling head stored in ``self.lm_head``."""
    head = self.lm_head
    return head
|
||||
|
||||
def get_prefix_bias_name(self):
    """Return ``self.name + "/" + self.lm_head.name`` (deprecated).

    Emits a ``FutureWarning`` directing callers to ``get_bias`` before
    returning the name.
    """
    # stacklevel=2 attributes the warning to the caller's line, which is the
    # code that actually needs to migrate away from this method.
    warnings.warn(
        "The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.",
        FutureWarning,
        stacklevel=2,
    )
    return self.name + "/" + self.lm_head.name
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user