From abd397e95483cc7486bed47fd271ab417e1a288e Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Tue, 5 Nov 2019 10:57:24 -0500 Subject: [PATCH] uniformize w/ the cache_dir update --- examples/run_xnli.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/run_xnli.py b/examples/run_xnli.py index 0c1743e6c4..e8457e542b 100644 --- a/examples/run_xnli.py +++ b/examples/run_xnli.py @@ -467,9 +467,17 @@ def main(): args.model_type = args.model_type.lower() config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name) - tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) - model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) + config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, + num_labels=num_labels, + finetuning_task=args.task_name, + cache_dir=args.cache_dir if args.cache_dir else None) + tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None) + model = model_class.from_pretrained(args.model_name_or_path, + from_tf=bool('.ckpt' in args.model_name_or_path), + config=config, + cache_dir=args.cache_dir if args.cache_dir else None) if args.local_rank == 0: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab