Full rework of the TF input/output embeddings and bias resizing (#9193)
* Start rework resizing * Rework bias/decoder resizing * Full resizing rework * Full resizing rework * Start to update the models with the new approach * Finish to update the models * Update all the tests * Update the template * Fix tests * Fix tests * Test a new approach * Refactoring * Refactoring * Refactoring * New rework * Rework BART * Rework bert+blenderbot * Rework CTRL * Rework Distilbert * Rework DPR * Rework Electra * Rework Flaubert * Rework Funnel * Rework GPT2 * Rework Longformer * Rework Lxmert * Rework marian+mbart * Rework mobilebert * Rework mpnet * Rework openai * Rework pegasus * Rework Roberta * Rework T5 * Rework xlm+xlnet * Rework template * Fix TFT5EncoderOnly + DPRs * Restore previous methods * Fix Funnel * Fix CTRL and TransforXL * Apply style * Apply Sylvain's comments * Restore a test in DPR * Address the comments * Fix bug * Apply style * remove unused import * Fix test * Forgot a method * missing test * Trigger CI * naming update * Rebase * Trigger CI
This commit is contained in:
@@ -41,7 +41,7 @@ class ModelTester(TFBartModelTester):
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
model_tester_cls = ModelTester
|
||||
@@ -59,14 +59,6 @@ class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -104,15 +96,87 @@ class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
x = model.get_output_layer_with_bias()
|
||||
assert x is None
|
||||
name = model.get_prefix_bias_name()
|
||||
assert name is None
|
||||
|
||||
if model_class in self.all_generative_model_classes:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert isinstance(name, dict)
|
||||
for k, v in name.items():
|
||||
assert isinstance(v, tf.Variable)
|
||||
else:
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_saved_model_creation(self):
|
||||
# This test is too long (>30sec) and makes fail the CI
|
||||
pass
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
old_final_logits_bias = model.get_bias()
|
||||
|
||||
# reshape the embeddings
|
||||
model.resize_token_embeddings(size)
|
||||
new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
new_final_logits_bias = model.get_bias()
|
||||
|
||||
# check that the resized embeddings size matches the desired size.
|
||||
assert_size = size if size is not None else config.vocab_size
|
||||
|
||||
self.assertEqual(new_input_embeddings.shape[0], assert_size)
|
||||
|
||||
# check that weights remain the same after resizing
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_output_embeddings is not None and new_output_embeddings is not None:
|
||||
self.assertEqual(new_output_embeddings.shape[0], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
if old_final_logits_bias is not None and new_final_logits_bias is not None:
|
||||
old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
|
||||
new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
|
||||
self.assertEqual(new_final_logits_bias.shape[0], 1)
|
||||
self.assertEqual(new_final_logits_bias.shape[1], assert_size)
|
||||
|
||||
models_equal = True
|
||||
for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()):
|
||||
for p1, p2 in zip(old, new):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user