From ba4bce2581f9a67caa44c3cc959a2dacb0090670 Mon Sep 17 00:00:00 2001 From: tuvuumass Date: Tue, 13 Aug 2019 11:26:27 -0400 Subject: [PATCH] fix issue #824 --- examples/run_bertology.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/run_bertology.py b/examples/run_bertology.py index 61c7440ecb..f11b73b54f 100644 --- a/examples/run_bertology.py +++ b/examples/run_bertology.py @@ -211,10 +211,12 @@ def prune_heads(args, model, eval_dataloader, head_mask): def main(): parser = argparse.ArgumentParser() + ## Required parameters parser.add_argument("--data_dir", default=None, type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.") - parser.add_argument("--model_name", default=None, type=str, required=True, - help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS)) + parser.add_argument("--model_name_or_path", default=None, type=str, required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join( + ALL_MODELS)) parser.add_argument("--task_name", default=None, type=str, required=True, help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) parser.add_argument("--output_dir", default=None, type=str, required=True, @@ -222,9 +224,9 @@ def main(): ## Other parameters parser.add_argument("--config_name", default="", type=str, - help="Pretrained config name or path if not the same as model_name") + help="Pretrained config name or path if not the same as model_name_or_path") parser.add_argument("--tokenizer_name", default="", type=str, - help="Pretrained tokenizer name or path if not the same as model_name") + help="Pretrained tokenizer name or path if not the same as model_name_or_path") parser.add_argument("--cache_dir", default="", type=str, help="Where do you want to store the pre-trained models downloaded from s3") parser.add_argument("--data_subset", type=int, default=-1, @@ -297,15 +299,15 @@ def main(): args.model_type = "" for key in MODEL_CLASSES: - if key in args.model_name.lower(): + if key in args.model_name_or_path.lower(): args.model_type = key # take the first match in model types break config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] - config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, + config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name, output_attentions=True) - tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name) - model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config) + tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path) + model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config) if args.local_rank == 0: torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab