Only keep the active part of the loss for token classification
This commit is contained in:
@@ -1025,7 +1025,14 @@ class BertForTokenClassification(PreTrainedBertModel):
|
|||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
# Only keep active parts of the loss
|
||||||
|
if attention_mask is not None:
|
||||||
|
active_loss = attention_mask.view(-1) == 1
|
||||||
|
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
||||||
|
active_labels = labels.view(-1)[active_loss]
|
||||||
|
loss = loss_fct(active_logits, active_labels)
|
||||||
|
else:
|
||||||
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
Reference in New Issue
Block a user