Use CrossEntropyLoss consistently in modeling_bart (#6265)
This commit is contained in:
@@ -1040,7 +1040,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
|
|
||||||
masked_lm_loss = None
|
masked_lm_loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
# TODO(SS): do we need to ignore pad tokens in labels?
|
# TODO(SS): do we need to ignore pad tokens in labels?
|
||||||
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
@@ -1179,7 +1179,8 @@ class BartForSequenceClassification(PretrainedBartModel):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
Reference in New Issue
Block a user