update cache_dir in readme and examples

This commit is contained in:
thomwolf
2018-11-26 10:45:13 +01:00
parent 63ae5d2134
commit 05053d163c
4 changed files with 8 additions and 5 deletions

View File

@@ -482,7 +482,8 @@ def main():
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
# Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
if args.fp16:
model.half()
model.to(device)