updated tokenizer loading for addressing reproducibility issues
This commit is contained in:
@@ -448,13 +448,14 @@ def main():
|
|||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||||
@@ -463,7 +464,6 @@ def main():
|
|||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint)
|
model = model_class.from_pretrained(checkpoint)
|
||||||
tokenizer = tokenizer_class.from_pretrained(checkpoint)
|
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||||
|
|||||||
Reference in New Issue
Block a user