set find_unused_parameters=True in DDP
This commit is contained in:
@@ -1468,12 +1468,13 @@ python -m torch.distributed.launch --nproc_per_node=8 \
|
|||||||
--do_lower_case \
|
--do_lower_case \
|
||||||
--train_file $SQUAD_DIR/train-v1.1.json \
|
--train_file $SQUAD_DIR/train-v1.1.json \
|
||||||
--predict_file $SQUAD_DIR/dev-v1.1.json \
|
--predict_file $SQUAD_DIR/dev-v1.1.json \
|
||||||
--train_batch_size 12 \
|
|
||||||
--learning_rate 3e-5 \
|
--learning_rate 3e-5 \
|
||||||
--num_train_epochs 2.0 \
|
--num_train_epochs 2 \
|
||||||
--max_seq_length 384 \
|
--max_seq_length 384 \
|
||||||
--doc_stride 128 \
|
--doc_stride 128 \
|
||||||
--output_dir /tmp/debug_squad/
|
--output_dir /tmp/debug_squad/ \
|
||||||
|
--train_batch_size 24 \
|
||||||
|
--gradient_accumulation_steps 2
|
||||||
```
|
```
|
||||||
|
|
||||||
## Notebooks
|
## Notebooks
|
||||||
|
|||||||
@@ -907,7 +907,10 @@ def main():
|
|||||||
# except ImportError:
|
# except ImportError:
|
||||||
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
|
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
|
model = torch.nn.parallel.DistributedDataParallel(model,
|
||||||
|
device_ids=[args.local_rank],
|
||||||
|
output_device=args.local_rank,
|
||||||
|
find_unused_parameters=True)
|
||||||
elif n_gpu > 1:
|
elif n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user