update double head model
This commit is contained in:
@@ -371,7 +371,7 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
|
|||||||
def forward(self, hidden_states, mc_token_ids):
|
def forward(self, hidden_states, mc_token_ids):
|
||||||
# Classification logits
|
# Classification logits
|
||||||
# hidden_state (bsz, num_choices, seq_length, hidden_size)
|
# hidden_state (bsz, num_choices, seq_length, hidden_size)
|
||||||
# mc_token_ids (bsz, num_choices, 1)
|
# mc_token_ids (bsz, num_choices)
|
||||||
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
|
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
|
||||||
# (bsz, num_choices, 1, hidden_size)
|
# (bsz, num_choices, 1, hidden_size)
|
||||||
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
|
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
|
||||||
|
|||||||
Reference in New Issue
Block a user