From c8ed1c82c8a42ef700d4129d227fa356385c1d60 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 13 Dec 2019 12:13:48 -0500 Subject: [PATCH] [SQUAD] Load checkpoint when evaluating without training --- examples/run_squad.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 117b86e32c..a39915ee8b 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -580,10 +580,16 @@ def main(): # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory results = {} if args.do_eval and args.local_rank in [-1, 0]: - checkpoints = [args.output_dir] - if args.eval_all_checkpoints: - checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) - logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + + if args.do_train: + logger.info("Loading checkpoints saved during training for evaluation") + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + else: + logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path) + checkpoints = [args.model_name_or_path] logger.info("Evaluate the following checkpoints: %s", checkpoints)