From ea9dbea9d5b65ca6333e378ea0a8a288399640c2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 7 May 2019 23:27:18 +0200 Subject: [PATCH] update GPT2 loss computation for more flexbility --- pytorch_pretrained_bert/modeling_gpt2.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 1c579de83c..ca5a38524a 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -336,6 +336,7 @@ class GPT2MultipleChoiceHead(nn.Module): # (bsz, num_choices, 1, hidden_size) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) # (bsz, num_choices, hidden_size) + multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2) multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) # (bsz, num_choices) return multiple_choice_logits @@ -665,9 +666,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): lm_logits = self.lm_head(hidden_states) if lm_labels is not None: # Shift so that tokens < n predict n - shift_logits = lm_logits[:, :-1].contiguous() - shift_labels = lm_labels[:, 1:].contiguous() - + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = lm_labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=-1) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), @@ -746,11 +746,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) losses = [] if lm_labels is not None: - shift_logits = lm_logits[:, :-1].contiguous() - shift_labels = lm_labels[:, 1:].contiguous() + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = lm_labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss(ignore_index=-1) - losses.append(loss_fct(shift_logits.view(-1, - shift_logits.size(-1)), shift_labels.view(-1))) + losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))) if mc_labels is not None: loss_fct = CrossEntropyLoss() losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))