From f3bda2352a2911739b85a1bc8fe65b0a33867c13 Mon Sep 17 00:00:00 2001 From: Thibault Fevry Date: Mon, 4 Feb 2019 11:46:36 -0500 Subject: [PATCH] Only keep the active part mof the loss for token classification --- pytorch_pretrained_bert/modeling.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index faa68ab939..d05ccd36bf 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -1025,7 +1025,14 @@ class BertForTokenClassification(PreTrainedBertModel): if labels is not None: loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return loss else: return logits