From 472857c47f3b6a142a7aaa53836e33cd8543088d Mon Sep 17 00:00:00 2001 From: Catalin Voss Date: Sun, 24 Mar 2019 13:49:42 -0700 Subject: [PATCH] Fix typo syntax err (sorry, c/p from my repo) --- pytorch_pretrained_bert/modeling_gpt2.py | 2 +- pytorch_pretrained_bert/modeling_openai.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 1733a5b3f4..15e7ca26e1 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -625,7 +625,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}] # We just flatten the tokens out this way. loss_fct = CrossEntropyLoss(ignore_index=-1) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss return lm_logits, presents diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 9c708f88a2..ab4107667b 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -724,7 +724,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}] # We just flatten the tokens out this way. loss_fct = CrossEntropyLoss(ignore_index=-1) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss return lm_logits