add multiple choice to robreta and xlnet, test on swag, roberta=0.82.28
, xlnet=0.80
This commit is contained in:
@@ -329,6 +329,46 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
|
||||
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
class RobertaForMultipleChoice(BertPreTrainedModel):
|
||||
config_class = RobertaConfig
|
||||
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super(RobertaForMultipleChoice, self).__init__(config)
|
||||
|
||||
self.roberta = RobertaModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
||||
position_ids=None, head_mask=None):
|
||||
num_choices = input_ids.shape[1]
|
||||
|
||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
outputs = self.roberta(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
|
||||
attention_mask=flat_attention_mask, head_mask=head_mask)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
|
||||
class RobertaClassificationHead(nn.Module):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
||||
Reference in New Issue
Block a user