update cache_dir in readme and examples
This commit is contained in:
@@ -482,7 +482,8 @@ def main():
|
||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
||||
|
||||
# Prepare model
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
||||
Reference in New Issue
Block a user