From c3527cfbc4ebce38b22e0419d4ca698f5b065688 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 5 Nov 2018 14:18:48 +0100 Subject: [PATCH] ignore SQuAD targets outside of seq_length --- modeling.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modeling.py b/modeling.py index 6b9cddd569..c467e8266e 100644 --- a/modeling.py +++ b/modeling.py @@ -455,9 +455,15 @@ class BertForQuestionAnswering(nn.Module): end_logits = end_logits.squeeze(-1) if start_positions is not None and end_positions is not None: - start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension + # If we are on multi-GPU, split add a dimension - if not this is a no-op + start_positions = start_positions.squeeze(-1) end_positions = end_positions.squeeze(-1) - loss_fct = CrossEntropyLoss() + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + 1 + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2