From 05053d163cbd6021e47487699e4e2de36c3b7720 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 26 Nov 2018 10:45:13 +0100 Subject: [PATCH] update cache_dir in readme and examples --- README.md | 6 +++--- examples/run_classifier.py | 3 ++- examples/run_squad.py | 3 ++- pytorch_pretrained_bert/__init__.py | 1 + 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 3d1e66cbbf..bf54538531 100644 --- a/README.md +++ b/README.md @@ -162,13 +162,12 @@ Here is a detailed documentation of the classes in the package and how to use th To load one of Google AI's pre-trained models or a PyTorch saved model (an instance of `BertForPreTraining` saved with `torch.save()`), the PyTorch model classes and the tokenizer can be instantiated as ```python -model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH) +model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None) ``` where - `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering`, and - - `PRE_TRAINED_MODEL_NAME_OR_PATH` is either: - the shortcut name of a Google AI's pre-trained model selected in the list: @@ -184,7 +183,8 @@ where - `bert_config.json` a configuration file for the model, and - `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`) -If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`). 
+ If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`). +- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information) Example: ```python diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 1626fc8164..5ceab4ae26 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -482,7 +482,8 @@ def main(): len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) # Prepare model - model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list)) + model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list), + cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) if args.fp16: model.half() model.to(device) diff --git a/examples/run_squad.py b/examples/run_squad.py index 7f58ddc35a..c13362b94e 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -821,7 +821,8 @@ def main(): len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) # Prepare model - model = BertForQuestionAnswering.from_pretrained(args.bert_model) + model = BertForQuestionAnswering.from_pretrained(args.bert_model, + cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) if args.fp16: model.half() model.to(device) diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index 
e12e48c2b9..7850fa5555 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -3,3 +3,4 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForQuestionAnswering) from .optimization import BertAdam +from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE