[SQUAD] Load checkpoint when evaluating without training
This commit is contained in:
@@ -580,10 +580,16 @@ def main():
|
|||||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||||
results = {}
|
results = {}
|
||||||
if args.do_eval and args.local_rank in [-1, 0]:
|
if args.do_eval and args.local_rank in [-1, 0]:
|
||||||
checkpoints = [args.output_dir]
|
|
||||||
if args.eval_all_checkpoints:
|
if args.do_train:
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
logger.info("Loading checkpoints saved during training for evaluation")
|
||||||
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
checkpoints = [args.output_dir]
|
||||||
|
if args.eval_all_checkpoints:
|
||||||
|
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||||
|
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||||
|
else:
|
||||||
|
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
|
||||||
|
checkpoints = [args.model_name_or_path]
|
||||||
|
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user