Full rework of the TF input/output embeddings and bias resizing (#9193)
* Start rework resizing * Rework bias/decoder resizing * Full resizing rework * Full resizing rework * Start to update the models with the new approach * Finish to update the models * Update all the tests * Update the template * Fix tests * Fix tests * Test a new approach * Refactoring * Refactoring * Refactoring * New rework * Rework BART * Rework bert+blenderbot * Rework CTRL * Rework Distilbert * Rework DPR * Rework Electra * Rework Flaubert * Rework Funnel * Rework GPT2 * Rework Longformer * Rework Lxmert * Rework marian+mbart * Rework mobilebert * Rework mpnet * Rework openai * Rework pegasus * Rework Roberta * Rework T5 * Rework xlm+xlnet * Rework template * Fix TFT5EncoderOnly + DPRs * Restore previous methods * Fix Funnel * Fix CTRL and TransforXL * Apply style * Apply Sylvain's comments * Restore a test in DPR * Address the comments * Fix bug * Apply style * remove unused import * Fix test * Forgot a method * missing test * Trigger CI * naming update * Rebase * Trigger CI
This commit is contained in:
@@ -460,6 +460,20 @@ class TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(tf.keras.layers.Lay
|
||||
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.input_embeddings.word_embeddings
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.input_embeddings.word_embeddings = value
|
||||
self.input_embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
def get_bias(self):
|
||||
return {"bias": self.bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.bias = value["bias"]
|
||||
self.vocab_size = shape_list(value["bias"])[0]
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
@@ -800,15 +814,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
|
||||
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.{{cookiecutter.lowercase_modelname}}.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
@@ -903,15 +911,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
|
||||
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.{{cookiecutter.lowercase_modelname}}.embeddings
|
||||
|
||||
def get_output_layer_with_bias(self):
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
def get_prefix_bias_name(self):
|
||||
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
|
||||
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
@@ -1855,6 +1857,29 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
def get_input_embeddings(self):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
return base_model.shared
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
try:
|
||||
base_model.shared.weight = value
|
||||
except AttributeError:
|
||||
self(self.dummy_inputs)
|
||||
base_model.shared.weight = value
|
||||
|
||||
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
|
||||
base_model.encoder.set_embed_tokens(embed_tokens)
|
||||
base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -1984,6 +2009,9 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -2124,6 +2152,9 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -2331,6 +2362,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
@@ -2452,15 +2486,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
encoder_attentions=enc_attns,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.shared = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The {{cookiecutter.uppercase_modelname}} Model with a language modeling head. Can be used for summarization.",
|
||||
@@ -2483,23 +2508,21 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens):
|
||||
super().resize_token_embeddings(new_num_tokens=new_num_tokens)
|
||||
def get_bias(self):
|
||||
return {"final_logits_bias": self.final_logits_bias}
|
||||
|
||||
# {{cookiecutter.uppercase_modelname}} is a special case where the bias has two dimensions
|
||||
# and not named just `bias`
|
||||
if new_num_tokens is not None:
|
||||
num_tokens_to_copy = min(shape_list(self.final_logits_bias)[0], new_num_tokens)
|
||||
init_bias = tf.zeros((new_num_tokens,))
|
||||
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
|
||||
self.final_logits_bias = self.add_weight(
|
||||
shape=(1, new_num_tokens),
|
||||
initializer="zeros",
|
||||
trainable=False,
|
||||
name="final_logits_bias",
|
||||
)
|
||||
self.final_logits_bias.assign(init_bias)
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@@ -2664,12 +2687,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
)
|
||||
return (past[0], reordered_past)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.model.shared
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def compute_loss(self, labels, logits):
|
||||
"""CrossEntropyLoss that ignores pad tokens"""
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||
|
||||
@@ -486,10 +486,82 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
x = model.get_output_layer_with_bias()
|
||||
assert x is None
|
||||
name = model.get_prefix_bias_name()
|
||||
assert name is None
|
||||
|
||||
if model_class in self.all_generative_model_classes:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert isinstance(name, dict)
|
||||
for k, v in name.items():
|
||||
assert isinstance(v, tf.Variable)
|
||||
else:
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
|
||||
Reference in New Issue
Block a user