From 1d203a34c06fb8b2c1de856d58950f9d193cc1fc Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 11 Apr 2019 23:51:03 +0200 Subject: [PATCH] back to simple indexing --- pytorch_pretrained_bert/modeling_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 1a2a3feb20..be4f959485 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -372,7 +372,7 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): # Classification logits # hidden_state (bsz, num_choices, seq_length, hidden_size) # mc_token_ids (bsz, num_choices, 1) - mc_token_ids = mc_token_ids.unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) + mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) # (bsz, num_choices, 1, hidden_size) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) # (bsz, num_choices, hidden_size)