Merge pull request #488 from dhpollack/fix_multichoice
fixed BertForMultipleChoice model init and forward pass
This commit is contained in:
@@ -1052,8 +1052,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
||||||
pooled_output = self.dropout(pooled_output)
|
pooled_output = self.dropout(pooled_output)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|||||||
Reference in New Issue
Block a user