Fix GPT language model loss here as well
This commit is contained in:
@@ -716,8 +716,16 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = lm_logits[:, :-1]
|
||||||
|
shift_labels = lm_labels[:, 1:]
|
||||||
|
|
||||||
|
# In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
|
||||||
|
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
|
||||||
|
# We just flatten the tokens out this way.
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))
|
||||||
|
shift_labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
return lm_logits
|
return lm_logits
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user