update readme - fix SQuAD model on multi-GPU

This commit is contained in:
thomwolf
2018-11-08 21:22:22 +01:00
parent 4850ec5888
commit 2c5d993ba4
2 changed files with 10 additions and 3 deletions

View File

@@ -455,9 +455,11 @@ class BertForQuestionAnswering(nn.Module):
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension - if not this is a no-op
start_positions = start_positions.squeeze(-1)
end_positions = end_positions.squeeze(-1)
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)