From 994d86609b480f3643287caa734a50b6b3b5aa4a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 6 Mar 2019 10:21:24 +0100 Subject: [PATCH] fixing PYTORCH_PRETRAINED_BERT_CACHE use in examples --- examples/run_classifier.py | 2 +- examples/run_squad.py | 2 +- examples/run_swag.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index ca3d1d2e70..f023a4cf20 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -495,7 +495,7 @@ def main(): num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() # Prepare model - cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)) + cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)) model = BertForSequenceClassification.from_pretrained(args.bert_model, cache_dir=cache_dir, num_labels = num_labels) diff --git a/examples/run_squad.py b/examples/run_squad.py index 4b2c1fd237..a29362ffb9 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -894,7 +894,7 @@ def main(): # Prepare model model = BertForQuestionAnswering.from_pretrained(args.bert_model, - cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))) + cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))) if args.fp16: model.half() diff --git a/examples/run_swag.py b/examples/run_swag.py index 12761468b0..2ebd637223 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -367,7 +367,7 @@ def main(): # Prepare model model = BertForMultipleChoice.from_pretrained(args.bert_model, - cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)), + cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)), num_choices=4) if args.fp16: model.half()