From fe2756ff41147ea6de14d8f81ecc5304382af91d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Apr 2019 10:04:05 +0200 Subject: [PATCH] update double head model --- pytorch_pretrained_bert/modeling_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index c4d20c331e..7b95d74f7c 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -371,7 +371,7 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): def forward(self, hidden_states, mc_token_ids): # Classification logits # hidden_state (bsz, num_choices, seq_length, hidden_size) - # mc_token_ids (bsz, num_choices, 1) + # mc_token_ids (bsz, num_choices) mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) # (bsz, num_choices, 1, hidden_size) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)