BartForSequenceClassification: fix num_labels, add test (#3110)

This commit is contained in:
Sam Shleifer
2020-03-03 15:54:29 -05:00
committed by GitHub
parent f631e01d2c
commit e9e6efdc45
2 changed files with 16 additions and 6 deletions

View File

@@ -1324,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
# Prepend logits
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
if labels is not None: # prepend loss to output,
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs