RFC: Replace custom TF embeddings by Keras embeddings (#18939)

This commit is contained in:
Joao Gante
2022-09-10 11:34:49 +01:00
committed by GitHub
parent 855dcae8bb
commit 00cbadb870
5 changed files with 141 additions and 142 deletions

View File

@@ -1144,30 +1144,20 @@ class TFModelTesterMixin:
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_resize_token_embeddings(self):
# TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
# tf.keras.layers.Embedding
if not self.test_resize_embeddings:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def _get_word_embedding_weight(model, embedding_layer):
embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
return embeds
embeds = getattr(embedding_layer, "decoder", None)
if embeds is not None:
return embeds
model(model.dummy_inputs)
embeds = getattr(embedding_layer, "weight", None)
if embeds is not None:
return embeds
embeds = getattr(embedding_layer, "decoder", None)
if embeds is not None:
return embeds
return None
if isinstance(embedding_layer, tf.keras.layers.Embedding):
# builds the embeddings layer
model(model.dummy_inputs)
return embedding_layer.embeddings
else:
return model._get_word_embedding_weight(embedding_layer)
for model_class in self.all_model_classes:
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
@@ -1195,10 +1185,10 @@ class TFModelTesterMixin:
if old_bias is not None and new_bias is not None:
for old_weight, new_weight in zip(old_bias.values(), new_bias.values()):
self.assertEqual(new_weight.shape[0], assert_size)
self.assertEqual(new_weight.shape[-1], assert_size)
models_equal = True
for p1, p2 in zip(old_weight.value(), new_weight.value()):
for p1, p2 in zip(tf.squeeze(old_weight), tf.squeeze(new_weight)):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)