From 80f995a141e496a5ff9d7996057e835f24371cd4 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 24 Apr 2019 16:51:54 +0200 Subject: [PATCH] revert BertForMultipleChoice linear classifier --- pytorch_pretrained_bert/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 374a57c34f..6b71f007c3 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -1034,7 +1034,7 @@ class BertForMultipleChoice(BertPreTrainedModel): self.num_choices = num_choices self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_choices) + self.classifier = nn.Linear(config.hidden_size, 1) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):