Fix #1623
This commit is contained in:
@@ -506,9 +506,15 @@ def main():
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model = model_class.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool('.ckpt' in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
|
||||
if args.teacher_type is not None:
|
||||
assert args.teacher_name_or_path is not None
|
||||
@@ -516,8 +522,11 @@ def main():
|
||||
assert args.alpha_ce + args.alpha_squad > 0.
|
||||
assert args.teacher_type != 'distilbert', "We constraint teachers not to be of type DistilBERT."
|
||||
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
||||
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path, config=teacher_config)
|
||||
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name_or_path,
|
||||
config=teacher_config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
teacher.to(args.device)
|
||||
else:
|
||||
teacher = None
|
||||
@@ -553,8 +562,10 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model = model_class.from_pretrained(args.output_dir, cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
@@ -571,7 +582,7 @@ def main():
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = model_class.from_pretrained(checkpoint, cache_dir=args.cache_dir if args.cache_dir else None)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
|
||||
Reference in New Issue
Block a user