fix issue #824
This commit is contained in:
@@ -211,10 +211,12 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
||||||
parser.add_argument("--model_name", default=None, type=str, required=True,
|
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||||
help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(
|
||||||
|
ALL_MODELS))
|
||||||
parser.add_argument("--task_name", default=None, type=str, required=True,
|
parser.add_argument("--task_name", default=None, type=str, required=True,
|
||||||
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||||
@@ -222,9 +224,9 @@ def main():
|
|||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument("--config_name", default="", type=str,
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
help="Pretrained config name or path if not the same as model_name_or_path")
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
help="Pretrained tokenizer name or path if not the same as model_name_or_path")
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
parser.add_argument("--cache_dir", default="", type=str,
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||||
parser.add_argument("--data_subset", type=int, default=-1,
|
parser.add_argument("--data_subset", type=int, default=-1,
|
||||||
@@ -297,15 +299,15 @@ def main():
|
|||||||
|
|
||||||
args.model_type = ""
|
args.model_type = ""
|
||||||
for key in MODEL_CLASSES:
|
for key in MODEL_CLASSES:
|
||||||
if key in args.model_name.lower():
|
if key in args.model_name_or_path.lower():
|
||||||
args.model_type = key # take the first match in model types
|
args.model_type = key # take the first match in model types
|
||||||
break
|
break
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name,
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||||
num_labels=num_labels, finetuning_task=args.task_name,
|
num_labels=num_labels, finetuning_task=args.task_name,
|
||||||
output_attentions=True)
|
output_attentions=True)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name)
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path)
|
||||||
model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
|
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|||||||
Reference in New Issue
Block a user