From 38ba7b439bcdadc73db03bfd7504fae44f74ab93 Mon Sep 17 00:00:00 2001 From: David Pollack Date: Mon, 15 Apr 2019 10:38:01 +0200 Subject: [PATCH] fixed BertForMultipleChoice model init and forward pass --- pytorch_pretrained_bert/modeling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 2736e34d7f..374a57c34f 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -1034,13 +1034,13 @@ class BertForMultipleChoice(BertPreTrainedModel): self.num_choices = num_choices self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) + self.classifier = nn.Linear(config.hidden_size, num_choices) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): flat_input_ids = input_ids.view(-1, input_ids.size(-1)) - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output)