From 8742baa53136b906e392fdae57f1c191d1b2370e Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 26 Nov 2019 14:39:47 -0500 Subject: [PATCH] Improve test protocol for inputs_embeds in TF --- transformers/tests/modeling_tf_common_test.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/transformers/tests/modeling_tf_common_test.py b/transformers/tests/modeling_tf_common_test.py index 232991915c..ea8cd1aecd 100644 --- a/transformers/tests/modeling_tf_common_test.py +++ b/transformers/tests/modeling_tf_common_test.py @@ -426,15 +426,17 @@ class TFCommonTestCases: try: x = wte([input_ids], mode="embedding") except: - x = wte([input_ids, None, None, None], mode="embedding") + try: + x = wte([input_ids, None, None, None], mode="embedding") + except: + if hasattr(self.model_tester, "embedding_size"): + x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32) + else: + x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) # ^^ In our TF models, the input_embeddings can take slightly different forms, # so we try a few of them. # We used to fall back to just synthetically creating a dummy tensor of ones: # - # if hasattr(self.model_tester, "embedding_size"): - # x = tf.ones(input_ids.shape + [self.model_tester.embedding_size], dtype=tf.dtypes.float32) - # else: - # x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32) inputs_dict["inputs_embeds"] = x outputs = model(inputs_dict)