Test barrier in distributed training
This commit is contained in:
23
README.md
23
README.md
@@ -1272,27 +1272,20 @@ python run_classifier.py \
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Distributed training**
|
**Distributed training**
|
||||||
Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking model to reach a F1 > 93 on SQuAD:
|
Here is an example using distributed training on 8 V100 GPUs and the BERT Whole Word Masking model to reach an F1 > 92 on MRPC:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m torch.distributed.launch --nproc_per_node=8 \
|
python -m torch.distributed.launch --nproc_per_node 8 run_classifier.py --bert_model bert-large-uncased-whole-word-masking --task_name MRPC --do_train --do_eval --do_lower_case --data_dir $GLUE_DIR/MRPC/ --max_seq_length 128 --train_batch_size 8 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir /tmp/mrpc_output/
|
||||||
run_classifier.py \
|
|
||||||
--bert_model bert-large-cased-whole-word-masking \
|
|
||||||
--task_name MRPC \
|
|
||||||
--do_train \
|
|
||||||
--do_eval \
|
|
||||||
--do_lower_case \
|
|
||||||
--data_dir $GLUE_DIR/MRPC/ \
|
|
||||||
--max_seq_length 128 \
|
|
||||||
--train_batch_size 64 \
|
|
||||||
--learning_rate 2e-5 \
|
|
||||||
--num_train_epochs 3.0 \
|
|
||||||
--output_dir /tmp/mrpc_output/
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Training with these hyper-parameters gave us the following results:
|
Training with these hyper-parameters gave us the following results:
|
||||||
```bash
|
```bash
|
||||||
{"exact_match": 86.91579943235573, "f1": 93.1532499015869}
|
acc = 0.8823529411764706
|
||||||
|
acc_and_f1 = 0.901702786377709
|
||||||
|
eval_loss = 0.3418912578906332
|
||||||
|
f1 = 0.9210526315789473
|
||||||
|
global_step = 174
|
||||||
|
loss = 0.07231863956341798
|
||||||
```
|
```
|
||||||
|
|
||||||
#### SQuAD
|
#### SQuAD
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ else:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def barrier():
|
||||||
|
t = torch.randn((), device='cuda')
|
||||||
|
torch.distributed.all_reduce(t)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@@ -201,10 +207,13 @@ def main():
|
|||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
|
if args.local_rank not in [-1, 0]:
|
||||||
|
barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
# Prepare model
|
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||||
|
if args.local_rank == 0:
|
||||||
|
barrier()
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
||||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
||||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
||||||
|
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||||
|
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
||||||
|
'bert-base-uncased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-finetuned-mrpc-pytorch_model.bin",
|
||||||
|
'bert-large-uncased-whole-word-masking-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-mrpc-pytorch_model.bin",
|
||||||
}
|
}
|
||||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
||||||
|
|||||||
Reference in New Issue
Block a user