Merge branch 'master' into python_2
This commit is contained in:
@@ -1067,7 +1067,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
|
||||
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with indices selected in [0, ..., num_labels - 1].
|
||||
|
||||
Outputs:
|
||||
@@ -1107,7 +1107,14 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
# Only keep active parts of the loss
|
||||
if attention_mask is not None:
|
||||
active_loss = attention_mask.view(-1) == 1
|
||||
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
||||
active_labels = labels.view(-1)[active_loss]
|
||||
loss = loss_fct(active_logits, active_labels)
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
return loss
|
||||
else:
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user