uniformize w/ the cache_dir update
This commit is contained in:
committed by
Lysandre Debut
parent
d75d49a51d
commit
abd397e954
@@ -467,9 +467,17 @@ def main():
|
|||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
num_labels=num_labels,
|
||||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
finetuning_task=args.task_name,
|
||||||
|
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||||
|
do_lower_case=args.do_lower_case,
|
||||||
|
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||||
|
model = model_class.from_pretrained(args.model_name_or_path,
|
||||||
|
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||||
|
config=config,
|
||||||
|
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|||||||
Reference in New Issue
Block a user