From 870d71636e88310eab4bd3b4459783f556a29a4a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 26 Nov 2018 09:51:34 +0100 Subject: [PATCH] fixing target size in crossentropy losses --- pytorch_pretrained_bert/modeling.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 9ef592d0dc..5dd1690313 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -678,8 +678,8 @@ class BertForPreTraining(PreTrainedBertModel): if masked_lm_labels is not None and next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels) - next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) total_loss = masked_lm_loss + next_sentence_loss return total_loss else: @@ -741,7 +741,7 @@ class BertForMaskedLM(PreTrainedBertModel): if masked_lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) return masked_lm_loss else: return prediction_scores @@ -803,7 +803,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel): if next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) - next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) return next_sentence_loss else: return seq_relationship_score @@ -856,6 +856,7 @@ class BertForSequenceClassification(PreTrainedBertModel): """ def __init__(self, config, num_labels=2): super(BertForSequenceClassification, self).__init__(config) + self.num_labels = num_labels self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) @@ -868,7 +869,7 @@ class BertForSequenceClassification(PreTrainedBertModel): if labels is not None: loss_fct = CrossEntropyLoss() - loss = loss_fct(logits, labels) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return loss, logits else: return logits