From 4d8c4337ae384262b48eed646f8586704e1bc530 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 22:41:28 +0200 Subject: [PATCH] test barrier in distrib training --- README.md | 23 ++++++++--------------- examples/run_classifier.py | 13 +++++++++++-- pytorch_pretrained_bert/modeling.py | 4 ++++ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 7d7fea4182..b0a155f140 100644 --- a/README.md +++ b/README.md @@ -1272,27 +1272,20 @@ python run_classifier.py \ ``` **Distributed training** -Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking model to reach a F1 > 93 on SQuAD: +Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking model to reach a F1 > 92 on MRPC: ```bash -python -m torch.distributed.launch --nproc_per_node=8 \ - run_classifier.py \ - --bert_model bert-large-cased-whole-word-masking \ - --task_name MRPC \ - --do_train \ - --do_eval \ - --do_lower_case \ - --data_dir $GLUE_DIR/MRPC/ \ - --max_seq_length 128 \ - --train_batch_size 64 \ - --learning_rate 2e-5 \ - --num_train_epochs 3.0 \ - --output_dir /tmp/mrpc_output/ +python -m torch.distributed.launch --nproc_per_node 8 run_classifier.py --bert_model bert-large-uncased-whole-word-masking --task_name MRPC --do_train --do_eval --do_lower_case --data_dir $GLUE_DIR/MRPC/ --max_seq_length 128 --train_batch_size 8 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir /tmp/mrpc_output/ ``` Training with these hyper-parameters gave us the following results: ```bash -{"exact_match": 86.91579943235573, "f1": 93.1532499015869} + acc = 0.8823529411764706 + acc_and_f1 = 0.901702786377709 + eval_loss = 0.3418912578906332 + f1 = 0.9210526315789473 + global_step = 174 + loss = 0.07231863956341798 ``` #### SQuAD diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 47cf43e17c..123efb9147 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -50,6 +50,12 @@ else: logger = logging.getLogger(__name__) +def barrier(): + t = torch.randn((), device='cuda') + torch.distributed.all_reduce(t) + torch.cuda.synchronize() + + def main(): parser = argparse.ArgumentParser() @@ -201,10 +207,13 @@ def main(): label_list = processor.get_labels() num_labels = len(label_list) + if args.local_rank not in [-1, 0]: + barrier() # Make sure only the first process in distributed training will download model & vocab tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) - - # Prepare model model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) + if args.local_rank == 0: + barrier() + if args.fp16: model.half() model.to(device) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 5074240685..4dfffb8e43 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -44,6 +44,10 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin", 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-base-uncased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-finetuned-mrpc-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-mrpc-pytorch_model.bin", } PRETRAINED_CONFIG_ARCHIVE_MAP = { 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",