cache in run_classifier + various fixes to the examples

This commit is contained in:
thomwolf
2019-06-18 15:58:22 +02:00
parent e6e5f19257
commit 15ebd67d4e
5 changed files with 665 additions and 624 deletions

View File

@@ -541,6 +541,7 @@ where
- `bert-base-german-cased`: Trained on German data only, 12-layer, 768-hidden, 12-heads, 110M parameters [Performance Evaluation](https://deepset.ai/german-bert)
- `bert-large-uncased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once)
- `bert-large-cased-whole-word-masking`: 24-layer, 1024-hidden, 16-heads, 340M parameters - Trained with Whole Word Masking (mask all of the the tokens corresponding to a word at once)
- `bert-large-uncased-whole-word-masking-finetuned-squad`: The `bert-large-uncased-whole-word-masking` model finetuned on SQuAD (using the `run_squad.py` examples). Results: *exact_match: 86.91579943235573, f1: 93.1532499015869*
- `openai-gpt`: OpenAI GPT English model, 12-layer, 768-hidden, 12-heads, 110M parameters
- `gpt2`: OpenAI GPT-2 English model, 12-layer, 768-hidden, 12-heads, 117M parameters
- `gpt2-medium`: OpenAI GPT-2 English model, 24-layer, 1024-hidden, 16-heads, 345M parameters
@@ -608,13 +609,15 @@ There are three types of files you need to save to be able to reload a fine-tune
- the configuration file of the model which is saved as a JSON file, and
- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
The defaults files names of these files are as follow:
The *default filenames* of these files are as follow:
- the model weights file: `pytorch_model.bin`,
- the configuration file: `config.json`,
- the vocabulary file: `vocab.txt` for BERT and Transformer-XL, `vocab.json` for GPT/GPT-2 (BPE vocabulary),
- for GPT/GPT-2 (BPE vocabulary) the additional merges file: `merges.txt`.
**If you save a model using these *default filenames*, you can then re-load the model and tokenizer using the `from_pretrained()` method.**
Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:
```python
@@ -1268,6 +1271,30 @@ python run_classifier.py \
--fp16
```
**Distributed training**
Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking model to reach a F1 > 93 on SQuAD:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
run_classifier.py \
--bert_model bert-large-cased-whole-word-masking \
--task_name MRPC \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $GLUE_DIR/MRPC/ \
--max_seq_length 128 \
--train_batch_size 64 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/mrpc_output/
```
Training with these hyper-parameters gave us the following results:
```bash
{"exact_match": 86.91579943235573, "f1": 93.1532499015869}
```
#### SQuAD
This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large) on a single tesla V100 16GB.
@@ -1298,9 +1325,36 @@ python run_squad.py \
Training with the previous hyper-parameters gave us the following results:
```bash
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json /tmp/debug_squad/predictions.json
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
```
Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking model to reach a F1 > 93 on SQuAD:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
run_squad.py \
--bert_model bert-large-cased-whole-word-masking \
--do_train \
--do_predict \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ../models/train_squad_large_cased_wwm/ \
--train_batch_size 24 \
--gradient_accumulation_steps 12
```
Training with these hyper-parameters gave us the following results:
```bash
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ../models/train_squad_large_cased_wwm/predictions.json
{"exact_match": 86.91579943235573, "f1": 93.1532499015869}
```
#### SWAG
The data for SWAG can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)