From f2f329408db66285fd59e6628ca394381bb7f94e Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 26 Nov 2019 12:59:28 -0500 Subject: [PATCH] Fix input embeddings --- transformers/tests/modeling_albert_test.py | 2 ++ transformers/tests/modeling_tf_albert_test.py | 2 ++ transformers/tests/modeling_tf_common_test.py | 7 ++++--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/transformers/tests/modeling_albert_test.py b/transformers/tests/modeling_albert_test.py index da87709df1..976feff9db 100644 --- a/transformers/tests/modeling_albert_test.py +++ b/transformers/tests/modeling_albert_test.py @@ -49,6 +49,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): use_token_type_ids=True, use_labels=True, vocab_size=99, + embedding_size=16, hidden_size=36, num_hidden_layers=6, num_hidden_groups=6, @@ -73,6 +74,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): self.use_token_type_ids = use_token_type_ids self.use_labels = use_labels self.vocab_size = vocab_size + self.embedding_size = embedding_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads diff --git a/transformers/tests/modeling_tf_albert_test.py b/transformers/tests/modeling_tf_albert_test.py index 85fc62f34f..fbd519b8f6 100644 --- a/transformers/tests/modeling_tf_albert_test.py +++ b/transformers/tests/modeling_tf_albert_test.py @@ -54,6 +54,7 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): use_token_type_ids=True, use_labels=True, vocab_size=99, + embedding_size=16, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, @@ -77,6 +78,7 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): self.use_token_type_ids = use_token_type_ids self.use_labels = use_labels self.vocab_size = vocab_size + self.embedding_size = embedding_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py index 2bb7cc9c5f..31a30766cf 100644 --- a/transformers/tests/modeling_tf_common_test.py +++ b/transformers/tests/modeling_tf_common_test.py @@ -426,9 +426,10 @@ class TFCommonTestCases: try: x = wte([input_ids], mode="embedding") except: - x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) - # ^^ In our TF models, the input_embeddings can take slightly different forms, - # so we try two of them and fall back to just synthetically creating a dummy tensor of ones. + if hasattr(self.model_tester, "embedding_size"): + x = tf.ones(input_ids.shape + [model.config.embedding_size], dtype=tf.dtypes.float32) + else: + x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) inputs_dict["inputs_embeds"] = x outputs = model(inputs_dict)