BartForSequenceClassification: fix num_labels, add test (#3110)
This commit is contained in:
@@ -1324,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
# Prepend logits
|
||||
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||
if labels is not None: # prepend loss to output,
|
||||
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user