diff --git a/modeling.py b/modeling.py index 6b9cddd569..c467e8266e 100644 --- a/modeling.py +++ b/modeling.py @@ -455,9 +455,15 @@ class BertForQuestionAnswering(nn.Module): end_logits = end_logits.squeeze(-1) if start_positions is not None and end_positions is not None: - start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension + # If we are on multi-GPU, split add a dimension - if not this is a no-op + start_positions = start_positions.squeeze(-1) end_positions = end_positions.squeeze(-1) - loss_fct = CrossEntropyLoss() + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + 1 + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2