[inputs_embeds] All TF models + tests
This commit is contained in:
@@ -131,10 +131,6 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = TFBertModel(config=config)
|
||||
# inputs = {'input_ids': input_ids,
|
||||
# 'attention_mask': input_mask,
|
||||
# 'token_type_ids': token_type_ids}
|
||||
# sequence_output, pooled_output = model(**inputs)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
|
||||
@@ -411,6 +411,27 @@ class TFCommonTestCases:
|
||||
first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
|
||||
self.assertTrue(tf.math.equal(first, second).numpy().all())
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
try:
|
||||
x = wte(input_ids, mode="embedding")
|
||||
except:
|
||||
try:
|
||||
x = wte([input_ids], mode="embedding")
|
||||
except:
|
||||
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
|
||||
# ^^ In our TF models, the input_embeddings can take slightly different forms,
|
||||
# so we try two of them and fall back to just synthetically creating a dummy tensor of ones.
|
||||
inputs_dict["inputs_embeds"] = x
|
||||
outputs = model(inputs_dict)
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
|
||||
Reference in New Issue
Block a user