Fix typo syntax err (sorry, c/p from my repo)
This commit is contained in:
@@ -625,7 +625,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
|
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
|
||||||
# We just flatten the tokens out this way.
|
# We just flatten the tokens out this way.
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
shift_labels.view(-1))
|
shift_labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
return lm_logits, presents
|
return lm_logits, presents
|
||||||
|
|||||||
@@ -724,7 +724,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
|
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
|
||||||
# We just flatten the tokens out this way.
|
# We just flatten the tokens out this way.
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
shift_labels.view(-1))
|
shift_labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
return lm_logits
|
return lm_logits
|
||||||
|
|||||||
Reference in New Issue
Block a user