From d82e5deeb126480b47f4368141c9fc7c0d733d45 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 12:13:14 +0200 Subject: [PATCH] set find_unused_parameters=True in DDP --- README.md | 7 ++++--- examples/run_squad.py | 5 ++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b4c0ab9f01..cd6abcacdd 100644 --- a/README.md +++ b/README.md @@ -1468,12 +1468,13 @@ python -m torch.distributed.launch --nproc_per_node=8 \ --do_lower_case \ --train_file $SQUAD_DIR/train-v1.1.json \ --predict_file $SQUAD_DIR/dev-v1.1.json \ - --train_batch_size 12 \ --learning_rate 3e-5 \ - --num_train_epochs 2.0 \ + --num_train_epochs 2 \ --max_seq_length 384 \ --doc_stride 128 \ - --output_dir /tmp/debug_squad/ + --output_dir /tmp/debug_squad/ \ + --train_batch_size 24 \ + --gradient_accumulation_steps 2 ``` ## Notebooks diff --git a/examples/run_squad.py b/examples/run_squad.py index 6378f443e4..313cb453af 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -907,7 +907,10 @@ def main(): # except ImportError: # raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) + model = torch.nn.parallel.DistributedDataParallel(model, + device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) elif n_gpu > 1: model = torch.nn.DataParallel(model)