diff --git a/examples/run_squad.py b/examples/run_squad.py index 117b86e32c..a39915ee8b 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -580,10 +580,16 @@ def main(): # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory results = {} if args.do_eval and args.local_rank in [-1, 0]: - checkpoints = [args.output_dir] - if args.eval_all_checkpoints: - checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) - logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + + if args.do_train: + logger.info("Loading checkpoints saved during training for evaluation") + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + else: + logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path) + checkpoints = [args.model_name_or_path] logger.info("Evaluate the following checkpoints: %s", checkpoints)