From cf62bdc962c53d9fb7a5820217f1bf844bb6da3b Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 26 Nov 2019 14:37:32 -0500 Subject: [PATCH] Improve test protocol for inputs_embeds in TF cc @lysandrejik --- transformers/tests/modeling_tf_common_test.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py index 31a30766cf..232991915c 100644 --- a/transformers/tests/modeling_tf_common_test.py +++ b/transformers/tests/modeling_tf_common_test.py @@ -426,10 +426,15 @@ class TFCommonTestCases: try: x = wte([input_ids], mode="embedding") except: - if hasattr(self.model_tester, "embedding_size"): - x = tf.ones(input_ids.shape + [model.config.embedding_size], dtype=tf.dtypes.float32) - else: - x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) + x = wte([input_ids, None, None, None], mode="embedding") + # ^^ In our TF models, the input_embeddings can take slightly different forms, + # so we try a few of them. + # We used to fall back to just synthetically creating a dummy tensor of ones: + # + # if hasattr(self.model_tester, "embedding_size"): + # x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32) + # else: + # x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) inputs_dict["inputs_embeds"] = x outputs = model(inputs_dict)