Make multiple choice models work with input_embeds (#4921)
This commit is contained in:
@@ -1359,12 +1359,17 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
# the linear classifier still needs to be trained
|
# the linear classifier still needs to be trained
|
||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
"""
|
"""
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||||
|
|
||||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||||
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||||
|
inputs_embeds = (
|
||||||
|
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||||
|
if inputs_embeds is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
|||||||
@@ -1202,7 +1202,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
|||||||
loss, classification_scores = outputs[:2]
|
loss, classification_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||||
|
|
||||||
# set global attention on question tokens
|
# set global attention on question tokens
|
||||||
if global_attention_mask is None:
|
if global_attention_mask is None:
|
||||||
@@ -1216,7 +1216,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
|||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
@@ -1225,6 +1225,11 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
|||||||
if global_attention_mask is not None
|
if global_attention_mask is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
flat_inputs_embeds = (
|
||||||
|
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||||
|
if inputs_embeds is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
outputs = self.longformer(
|
outputs = self.longformer(
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
@@ -1232,6 +1237,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
|||||||
token_type_ids=flat_token_type_ids,
|
token_type_ids=flat_token_type_ids,
|
||||||
attention_mask=flat_attention_mask,
|
attention_mask=flat_attention_mask,
|
||||||
global_attention_mask=flat_global_attention_mask,
|
global_attention_mask=flat_global_attention_mask,
|
||||||
|
inputs_embeds=flat_inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|||||||
@@ -448,18 +448,25 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
|
|||||||
loss, classification_scores = outputs[:2]
|
loss, classification_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||||
|
|
||||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
|
flat_inputs_embeds = (
|
||||||
|
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||||
|
if inputs_embeds is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
outputs = self.roberta(
|
outputs = self.roberta(
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
position_ids=flat_position_ids,
|
position_ids=flat_position_ids,
|
||||||
token_type_ids=flat_token_type_ids,
|
token_type_ids=flat_token_type_ids,
|
||||||
attention_mask=flat_attention_mask,
|
attention_mask=flat_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=flat_inputs_embeds,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|||||||
@@ -1438,12 +1438,17 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
|||||||
loss, classification_scores = outputs[:2]
|
loss, classification_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||||
|
|
||||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
|
flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
|
||||||
|
flat_inputs_embeds = (
|
||||||
|
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
||||||
|
if inputs_embeds is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
@@ -1454,7 +1459,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
|||||||
perm_mask=perm_mask,
|
perm_mask=perm_mask,
|
||||||
target_mapping=target_mapping,
|
target_mapping=target_mapping,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=flat_inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -639,31 +639,31 @@ class ModelTesterMixin:
|
|||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if not self.is_encoder_decoder:
|
|
||||||
input_ids = inputs_dict["input_ids"]
|
|
||||||
del inputs_dict["input_ids"]
|
|
||||||
else:
|
|
||||||
encoder_input_ids = inputs_dict["input_ids"]
|
|
||||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
|
|
||||||
del inputs_dict["input_ids"]
|
|
||||||
inputs_dict.pop("decoder_input_ids", None)
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
|
||||||
continue
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
if not self.is_encoder_decoder:
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
del inputs["input_ids"]
|
||||||
|
else:
|
||||||
|
encoder_input_ids = inputs["input_ids"]
|
||||||
|
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||||
|
del inputs["input_ids"]
|
||||||
|
inputs.pop("decoder_input_ids", None)
|
||||||
|
|
||||||
wte = model.get_input_embeddings()
|
wte = model.get_input_embeddings()
|
||||||
if not self.is_encoder_decoder:
|
if not self.is_encoder_decoder:
|
||||||
inputs_dict["inputs_embeds"] = wte(input_ids)
|
inputs["inputs_embeds"] = wte(input_ids)
|
||||||
else:
|
else:
|
||||||
inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
|
inputs["inputs_embeds"] = wte(encoder_input_ids)
|
||||||
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model(**inputs_dict)
|
model(**inputs)
|
||||||
|
|
||||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user