ignore SQuAD targets outside of seq_length

This commit is contained in:
thomwolf
2018-11-05 14:18:48 +01:00
parent 1b99cdf71b
commit c3527cfbc4

View File

@@ -455,9 +455,15 @@ class BertForQuestionAnswering(nn.Module):
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension # If we are on multi-GPU, split add a dimension - if not this is a no-op
start_positions = start_positions.squeeze(-1)
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
loss_fct = CrossEntropyLoss() # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) + 1
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions) end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2