fixing PYTORCH_PRETRAINED_BERT_CACHE use in examples

This commit is contained in:
thomwolf
2019-03-06 10:21:24 +01:00
parent 2dd8f524f5
commit 994d86609b
3 changed files with 3 additions and 3 deletions

View File

@@ -495,7 +495,7 @@ def main():
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model # Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)) cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
model = BertForSequenceClassification.from_pretrained(args.bert_model, model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=cache_dir, cache_dir=cache_dir,
num_labels = num_labels) num_labels = num_labels)

View File

@@ -894,7 +894,7 @@ def main():
# Prepare model # Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model, model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))) cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
if args.fp16: if args.fp16:
model.half() model.half()

View File

@@ -367,7 +367,7 @@ def main():
# Prepare model # Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model, model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)), cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
num_choices=4) num_choices=4)
if args.fp16: if args.fp16:
model.half() model.half()