From d541938c48f759522f81fa177aae49098e0e0149 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 10 Jun 2020 18:38:34 -0400 Subject: [PATCH] Make multiple choice models work with input_embeds (#4921) --- src/transformers/modeling_bert.py | 9 ++++++-- src/transformers/modeling_longformer.py | 10 +++++++-- src/transformers/modeling_roberta.py | 11 ++++++++-- src/transformers/modeling_xlnet.py | 11 +++++++--- tests/test_modeling_common.py | 28 ++++++++++++------------- 5 files changed, 46 insertions(+), 23 deletions(-) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 4eb857db21..666495f87b 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -1359,12 +1359,17 @@ class BertForMultipleChoice(BertPreTrainedModel): # the linear classifier still needs to be trained loss, logits = outputs[:2] """ - num_choices = input_ids.shape[1] + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - input_ids = input_ids.view(-1, input_ids.size(-1)) + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.bert( input_ids, diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index aaf33b078d..a6ba9f739e 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -1202,7 +1202,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel): loss, classification_scores = outputs[:2] """ - num_choices = input_ids.shape[1] + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] # set global attention on question tokens if global_attention_mask is None: @@ -1216,7 +1216,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel): dim=1, ) - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None @@ -1225,6 +1225,11 @@ class LongformerForMultipleChoice(BertPreTrainedModel): if global_attention_mask is not None else None ) + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.longformer( flat_input_ids, @@ -1232,6 +1237,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel): token_type_ids=flat_token_type_ids, attention_mask=flat_attention_mask, global_attention_mask=flat_global_attention_mask, + inputs_embeds=flat_inputs_embeds, output_attentions=output_attentions, ) pooled_output = outputs[1] diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 91807c13ae..579be366bd 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -448,18 +448,25 @@ class RobertaForMultipleChoice(BertPreTrainedModel): loss, classification_scores = outputs[:2] """ - num_choices = input_ids.shape[1] + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + outputs = self.roberta( flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids, attention_mask=flat_attention_mask, head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, output_attentions=output_attentions, ) pooled_output = outputs[1] diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 404b2433d9..9ddac6a51e 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -1438,12 +1438,17 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): loss, classification_scores = outputs[:2] """ - num_choices = input_ids.shape[1] + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) transformer_outputs = self.transformer( flat_input_ids, @@ -1454,7 +1459,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): perm_mask=perm_mask, target_mapping=target_mapping, head_mask=head_mask, - inputs_embeds=inputs_embeds, + inputs_embeds=flat_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, ) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ef00cd8694..ef805d91e1 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -639,31 +639,31 @@ class ModelTesterMixin: def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if not self.is_encoder_decoder: - input_ids = inputs_dict["input_ids"] - del inputs_dict["input_ids"] - else: - encoder_input_ids = inputs_dict["input_ids"] - decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids) - del inputs_dict["input_ids"] - inputs_dict.pop("decoder_input_ids", None) for model_class in self.all_model_classes: - if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): - continue model = model_class(config) model.to(torch_device) model.eval() + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + wte = model.get_input_embeddings() if not self.is_encoder_decoder: - inputs_dict["inputs_embeds"] = wte(input_ids) + inputs["inputs_embeds"] = wte(input_ids) else: - inputs_dict["inputs_embeds"] = wte(encoder_input_ids) - inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids) + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) with torch.no_grad(): - model(**inputs_dict) + model(**inputs) def test_lm_head_model_random_no_beam_search_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()