fix multi-gpu squad loss

This commit is contained in:
thomwolf
2018-11-05 13:46:14 +01:00
parent 955cee33a5
commit 2f4765d3ed
2 changed files with 35 additions and 16 deletions

View File

@@ -455,6 +455,8 @@ class BertForQuestionAnswering(nn.Module):
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension
end_positions = end_positions.squeeze(-1)
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)