From 3ebf1a13c9db61c32b2d589a8823ef30485f0304 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 17:49:35 -0400 Subject: [PATCH] Fix loss computation for indexes bigger than max_seq_length. --- modeling_pytorch.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/modeling_pytorch.py b/modeling_pytorch.py index 4a8514e3a0..b227dfeb91 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -485,9 +485,22 @@ class BertForQuestionAnswering(nn.Module): start_logits, end_logits = logits.split(1, dim=-1) if start_positions is not None and end_positions is not None: - loss_fct = CrossEntropyLoss() - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) + #loss_fct = CrossEntropyLoss() + #start_loss = loss_fct(start_logits, start_positions) + #end_loss = loss_fct(end_logits, end_positions) + batch_size, seq_length = input_ids.size() + + def compute_loss(logits, positions): + max_position = positions.max().item() + one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1).zero_() + one_hot = one_hot.scatter(1, positions, 1) + one_hot = one_hot[:, :seq_length] + log_probs = nn.functional.log_softmax(logits, dim = -1).view(batch_size, seq_length) + loss = -torch.mean(torch.sum(one_hot*log_probs), dim = -1) + return loss + + start_loss = compute_loss(start_logits, start_positions) + end_loss = compute_loss(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 return total_loss, (start_logits, end_logits) else: