update readme - fix SQuAD model on multi-GPU
This commit is contained in:
@@ -194,3 +194,8 @@ python run_squad.py \
|
|||||||
--doc_stride 128 \
|
--doc_stride 128 \
|
||||||
--output_dir ../debug_squad/
|
--output_dir ../debug_squad/
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Training with the previous hyper-parameters and a batch size of 32 (on 4 GPUs) for 2 epochs gave us the following results:
|
||||||
|
```bash
|
||||||
|
{"f1": 88.19829549714827, "exact_match": 80.75685903500474}
|
||||||
|
```
|
||||||
|
|||||||
@@ -455,9 +455,11 @@ class BertForQuestionAnswering(nn.Module):
|
|||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
if start_positions is not None and end_positions is not None:
|
if start_positions is not None and end_positions is not None:
|
||||||
# If we are on multi-GPU, the split adds a dimension - if not, this is a no-op
|
# If we are on multi-GPU, the split adds a dimension
|
||||||
start_positions = start_positions.squeeze(-1)
|
if len(start_positions.size()) > 1:
|
||||||
end_positions = end_positions.squeeze(-1)
|
start_positions = start_positions.squeeze(-1)
|
||||||
|
if len(end_positions.size()) > 1:
|
||||||
|
end_positions = end_positions.squeeze(-1)
|
||||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
ignored_index = start_logits.size(1)
|
ignored_index = start_logits.size(1)
|
||||||
start_positions.clamp_(0, ignored_index)
|
start_positions.clamp_(0, ignored_index)
|
||||||
|
|||||||
Reference in New Issue
Block a user