From bf34a252b812237b804f66529d0edf8b81707c8a Mon Sep 17 00:00:00 2001 From: jinoobaek-qz Date: Mon, 7 Oct 2019 15:26:57 -0700 Subject: [PATCH] Golden path --- examples/run_lm_finetuning.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 09ae833f37..eb154e8180 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -114,23 +114,25 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False): # Check if we should delete older checkpoint(s) glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix))) - if len(glob_checkpoints) > args.save_total_limit: - checkpoints_sorted = [] - for path in glob_checkpoints: - regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path) - if regex_match and regex_match.groups(): - if use_mtime: - checkpoints_sorted.append((os.path.getmtime(path), path)) - else: - checkpoints_sorted.append((int(regex_match.groups()[0]), path)) + if len(glob_checkpoints) <= args.save_total_limit: + return - checkpoints_sorted = sorted(checkpoints_sorted) - checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] - number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit) - checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] - for checkpoint in checkpoints_to_be_deleted: - logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) - shutil.rmtree(checkpoint) + checkpoints_sorted = [] + for path in glob_checkpoints: + regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path) + if regex_match and regex_match.groups(): + if use_mtime: + checkpoints_sorted.append((os.path.getmtime(path), path)) + else: + checkpoints_sorted.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(checkpoints_sorted) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) + shutil.rmtree(checkpoint) def mask_tokens(inputs, tokenizer, args):