From 2e6f5ffb96029398f740b6eacdc86b117cccb86b Mon Sep 17 00:00:00 2001
From: Catalin Voss
Date: Sun, 24 Mar 2019 13:36:46 -0700
Subject: [PATCH] Fix GPT language model loss here as well

---
 pytorch_pretrained_bert/modeling_openai.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py
index 296abbfc31..9c708f88a2 100644
--- a/pytorch_pretrained_bert/modeling_openai.py
+++ b/pytorch_pretrained_bert/modeling_openai.py
@@ -716,8 +716,16 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[:, :-1]
+            shift_labels = lm_labels[:, 1:]
+
+            # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
+            # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
+            # We just flatten the tokens out this way.
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                            shift_labels.view(-1))
             return loss
         return lm_logits
 