From b509bf765574852648020d60690386b80e970cf6 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 12 Apr 2019 12:12:33 +0200 Subject: [PATCH] updating loss computation --- pytorch_pretrained_bert/modeling_openai.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index be4f959485..c4d20c331e 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -716,9 +716,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): lm_logits = self.lm_head(hidden_states) if lm_labels is not None: # Shift so that tokens < n predict n - shift_logits = lm_logits[:, :-1].contiguous() - shift_labels = lm_labels[:, 1:].contiguous() - + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = lm_labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-1) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), @@ -808,11 +807,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) losses = [] if lm_labels is not None: - shift_logits = lm_logits[:, :-1].contiguous() - shift_labels = lm_labels[:, 1:].contiguous() + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = lm_labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss(ignore_index=-1) - losses.append(loss_fct(shift_logits.view(-1, - shift_logits.size(-1)), shift_labels.view(-1))) + losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))) if mc_labels is not None: loss_fct = CrossEntropyLoss() losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))