[inputs_embeds] All PyTorch models
This commit is contained in:
@@ -525,6 +525,19 @@ class CommonTestCases:
|
||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.eval()
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs_dict["inputs_embeds"] = wte(input_ids)
|
||||
outputs = model(**inputs_dict)
|
||||
|
||||
|
||||
class GPTModelTester(CommonModelTester):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user