fixing target size in crossentropy losses
This commit is contained in:
@@ -678,8 +678,8 @@ class BertForPreTraining(PreTrainedBertModel):
|
||||
|
||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
|
||||
next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
return total_loss
|
||||
else:
|
||||
@@ -741,7 +741,7 @@ class BertForMaskedLM(PreTrainedBertModel):
|
||||
|
||||
if masked_lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
return masked_lm_loss
|
||||
else:
|
||||
return prediction_scores
|
||||
@@ -803,7 +803,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
|
||||
|
||||
if next_sentence_label is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
return next_sentence_loss
|
||||
else:
|
||||
return seq_relationship_score
|
||||
@@ -856,6 +856,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
||||
"""
|
||||
def __init__(self, config, num_labels=2):
|
||||
super(BertForSequenceClassification, self).__init__(config)
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
@@ -868,7 +869,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
return loss, logits
|
||||
else:
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user