Fix input embeddings
This commit is contained in:
@@ -426,9 +426,10 @@ class TFCommonTestCases:
|
||||
try:
|
||||
x = wte([input_ids], mode="embedding")
|
||||
except:
|
||||
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
|
||||
# ^^ In our TF models, the input_embeddings can take slightly different forms,
|
||||
# so we try two of them and fall back to just synthetically creating a dummy tensor of ones.
|
||||
if hasattr(self.model_tester, "embedding_size"):
|
||||
x = tf.ones(input_ids.shape + [model.config.embedding_size], dtype=tf.dtypes.float32)
|
||||
else:
|
||||
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
|
||||
inputs_dict["inputs_embeds"] = x
|
||||
outputs = model(inputs_dict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user