added LM head for OpenAI
This commit is contained in:
@@ -267,11 +267,11 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, hidden_states, classification_token_mask):
|
||||
def forward(self, hidden_states, multiple_choice_token_mask):
|
||||
# Classification logits
|
||||
# hidden_states = hidden_states.view(-1, self.n_embd)
|
||||
# classification_token_mask = classification_token_mask.view(-1, 1).expand_as(hidden_states)
|
||||
multiple_choice_h = hidden_states * classification_token_mask.unsqueeze(-1)
|
||||
# multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states)
|
||||
multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1)
|
||||
multiple_choice_h = multiple_choice_h.sum(dim=-2)
|
||||
# flat = x[..., 0].contiguous().view(-1)
|
||||
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
|
||||
@@ -496,8 +496,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
if lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(lm_logits, lm_labels)
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
|
||||
return loss
|
||||
return lm_logits
|
||||
|
||||
@@ -515,15 +515,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||
|
||||
def forward(self, input_ids, classification_token_mask, position_ids=None, token_type_ids=None,
|
||||
def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
|
||||
lm_labels=None, multiple_choice_labels=None):
|
||||
"""
|
||||
input_ids as to be of shape B x C x S
|
||||
""" input_ids should be of shape B x C x S
|
||||
lm_labels can be masked using the -1 value
|
||||
"""
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
multiple_choice_logits = self.multiple_choice_head(hidden_states, classification_token_mask)
|
||||
multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)
|
||||
losses = []
|
||||
if lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
|
||||
Reference in New Issue
Block a user