Make multiple choice models work with input_embeds (#4921)

This commit is contained in:
Sylvain Gugger
2020-06-10 18:38:34 -04:00
committed by GitHub
parent 1e2631d6f8
commit d541938c48
5 changed files with 46 additions and 23 deletions

View File

@@ -639,31 +639,31 @@ class ModelTesterMixin:
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.is_encoder_decoder:
input_ids = inputs_dict["input_ids"]
del inputs_dict["input_ids"]
else:
encoder_input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
del inputs_dict["input_ids"]
inputs_dict.pop("decoder_input_ids", None)
for model_class in self.all_model_classes:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
continue
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
del inputs["input_ids"]
else:
encoder_input_ids = inputs["input_ids"]
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
del inputs["input_ids"]
inputs.pop("decoder_input_ids", None)
wte = model.get_input_embeddings()
if not self.is_encoder_decoder:
inputs_dict["inputs_embeds"] = wte(input_ids)
inputs["inputs_embeds"] = wte(input_ids)
else:
inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
inputs["inputs_embeds"] = wte(encoder_input_ids)
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
model(**inputs_dict)
model(**inputs)
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()