Full rework of the TF input/output embeddings and bias resizing (#9193)
* Start rework resizing * Rework bias/decoder resizing * Full resizing rework * Full resizing rework * Start to update the models with the new approach * Finish to update the models * Update all the tests * Update the template * Fix tests * Fix tests * Test a new approach * Refactoring * Refactoring * Refactoring * New rework * Rework BART * Rework bert+blenderbot * Rework CTRL * Rework Distilbert * Rework DPR * Rework Electra * Rework Flaubert * Rework Funnel * Rework GPT2 * Rework Longformer * Rework Lxmert * Rework marian+mbart * Rework mobilebert * Rework mpnet * Rework openai * Rework pegasus * Rework Roberta * Rework T5 * Rework xlm+xlnet * Rework template * Fix TFT5EncoderOnly + DPRs * Restore previous methods * Fix Funnel * Fix CTRL and TransforXL * Apply style * Apply Sylvain's comments * Restore a test in DPR * Address the comments * Fix bug * Apply style * remove unused import * Fix test * Forgot a method * missing test * Trigger CI * naming update * Rebase * Trigger CI
This commit is contained in:
@@ -41,7 +41,6 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TFAdaptiveEmbedding,
|
||||
TFSharedEmbeddings,
|
||||
tf_top_k_top_p_filtering,
|
||||
)
|
||||
@@ -671,18 +670,20 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), (tf.keras.layers.Layer, TFAdaptiveEmbedding))
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
|
||||
if model_class in list_lm_models:
|
||||
x = model.get_output_layer_with_bias()
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_prefix_bias_name()
|
||||
assert isinstance(name, str)
|
||||
name = model.get_bias()
|
||||
assert isinstance(name, dict)
|
||||
for k, v in name.items():
|
||||
assert isinstance(v, tf.Variable)
|
||||
else:
|
||||
x = model.get_output_layer_with_bias()
|
||||
assert x is None
|
||||
name = model.get_prefix_bias_name()
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_determinism(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -830,26 +831,71 @@ class TFModelTesterMixin:
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
INPUT_SHAPE = [1, 10, config.hidden_size]
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "word_embeddings"):
|
||||
return embedding_layer.word_embeddings
|
||||
elif hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
elif hasattr(embedding_layer, "decoder"):
|
||||
return embedding_layer.decoder
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "word_embeddings"):
|
||||
return embedding_layer.word_embeddings
|
||||
elif hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
elif hasattr(embedding_layer, "decoder"):
|
||||
return embedding_layer.decoder
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
emb_old = model.get_input_embeddings()
|
||||
emb_old.build(INPUT_SHAPE)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_bias = model.get_bias()
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
# reshape the embeddings
|
||||
new_embeddings = model._get_resized_embeddings(emb_old, size)
|
||||
# # check that the resized embeddings size matches the desired size.
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_bias = model.get_bias()
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
self.assertEqual(new_embeddings.shape[0], assert_size)
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
emd_old_weights = model._get_word_embeddings(emb_old)
|
||||
models_equal = True
|
||||
for p1, p2 in zip(emd_old_weights.numpy(), new_embeddings.numpy()):
|
||||
if np.sum(abs(p1 - p2)) > 0:
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_bias is not None and new_bias is not None:
|
||||
for old_weight, new_weight in zip(old_bias.values(), new_bias.values()):
|
||||
self.assertEqual(new_weight.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_weight.value(), new_weight.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
self.assertEqual(new_output_embeddings.shape[1], old_output_embeddings.shape[1])
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
Reference in New Issue
Block a user