Fixing target size in cross-entropy losses
This commit is contained in:
@@ -678,8 +678,8 @@ class BertForPreTraining(PreTrainedBertModel):
|
|||||||
|
|
||||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||||
next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
|
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||||
total_loss = masked_lm_loss + next_sentence_loss
|
total_loss = masked_lm_loss + next_sentence_loss
|
||||||
return total_loss
|
return total_loss
|
||||||
else:
|
else:
|
||||||
@@ -741,7 +741,7 @@ class BertForMaskedLM(PreTrainedBertModel):
|
|||||||
|
|
||||||
if masked_lm_labels is not None:
|
if masked_lm_labels is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||||
return masked_lm_loss
|
return masked_lm_loss
|
||||||
else:
|
else:
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
@@ -803,7 +803,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
|
|||||||
|
|
||||||
if next_sentence_label is not None:
|
if next_sentence_label is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
|
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||||
return next_sentence_loss
|
return next_sentence_loss
|
||||||
else:
|
else:
|
||||||
return seq_relationship_score
|
return seq_relationship_score
|
||||||
@@ -856,6 +856,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, config, num_labels=2):
|
def __init__(self, config, num_labels=2):
|
||||||
super(BertForSequenceClassification, self).__init__(config)
|
super(BertForSequenceClassification, self).__init__(config)
|
||||||
|
self.num_labels = num_labels
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||||
@@ -868,7 +869,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
|
|||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits, labels)
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss, logits
|
return loss, logits
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
Reference in New Issue
Block a user