From 528d3f327b16c4bc2115f05d883e7d8eafe3e5e6 Mon Sep 17 00:00:00 2001 From: jinoobaek-qz Date: Mon, 7 Oct 2019 15:25:09 -0700 Subject: [PATCH] Improve readability and improve make less assumptions about checkpoint format --- examples/run_lm_finetuning.py | 44 +++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 6fb441db7a..09ae833f37 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -106,24 +106,31 @@ def set_seed(args): torch.cuda.manual_seed_all(args.seed) -def rotate_checkpoints(args): - if args.save_total_limit and args.save_total_limit > 0: - # Check if we should delete older checkpoint(s) - glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*')) - if len(glob_checkpoints) > args.save_total_limit: - checkpoints_sorted = [] - for path in glob_checkpoints: - regex_match = re.match('.*checkpoint-([0-9]+)', path) - if regex_match and regex_match.groups(): +def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False): + if not args.save_total_limit: + return + if args.save_total_limit <= 0: + return + + # Check if we should delete older checkpoint(s) + glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix))) + if len(glob_checkpoints) > args.save_total_limit: + checkpoints_sorted = [] + for path in glob_checkpoints: + regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path) + if regex_match and regex_match.groups(): + if use_mtime: + checkpoints_sorted.append((os.path.getmtime(path), path)) + else: checkpoints_sorted.append((int(regex_match.groups()[0]), path)) - checkpoints_sorted = sorted(checkpoints_sorted) - checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] - number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit) - checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] - for checkpoint in checkpoints_to_be_deleted: - logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) - shutil.rmtree(checkpoint) + checkpoints_sorted = sorted(checkpoints_sorted) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) + shutil.rmtree(checkpoint) def mask_tokens(inputs, tokenizer, args): @@ -244,8 +251,9 @@ def train(args, train_dataset, model, tokenizer): logging_loss = tr_loss if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + checkpoint_prefix = 'checkpoint' # Save model checkpoint - output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) + output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir) model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training @@ -253,7 +261,7 @@ def train(args, train_dataset, model, tokenizer): torch.save(args, os.path.join(output_dir, 'training_args.bin')) logger.info("Saving model checkpoint to %s", output_dir) - rotate_checkpoints(args) + _rotate_checkpoints(args, checkpoint_prefix) if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close()