Merge branch 'master' into python_2

This commit is contained in:
Thomas Wolf
2019-02-06 00:13:20 +01:00
committed by GitHub
8 changed files with 271 additions and 175 deletions

View File

@@ -1067,7 +1067,7 @@ class BertForTokenClassification(BertPreTrainedModel):
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [0, ..., num_labels].
Outputs:
@@ -1107,7 +1107,14 @@ class BertForTokenClassification(BertPreTrainedModel):
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits