diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index c167703d7b..7e8fd74f64 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -27,6 +27,8 @@ import logging import os import pickle import random +import re +import shutil import numpy as np import torch @@ -222,6 +224,24 @@ def train(args, train_dataset, model, tokenizer): logging_loss = tr_loss if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + if args.save_total_limit and args.save_total_limit > 0: + # Check if we should delete older checkpoint(s) + glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*')) + if len(glob_checkpoints) + 1 > args.save_total_limit: + checkpoints_sorted = [] + for path in glob_checkpoints: + regex_match = re.match('.*checkpoint-([0-9]+)', path) + if regex_match and regex_match.groups(): + checkpoints_sorted.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(checkpoints_sorted) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) + 1 - args.save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) + shutil.rmtree(checkpoint) + # Save model checkpoint output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) if not os.path.exists(output_dir): @@ -359,6 +379,8 @@ def main(): help="Log every X updates steps.") parser.add_argument('--save_steps', type=int, default=50, help="Save checkpoint every X updates steps.") + parser.add_argument('--save_total_limit', type=int, default=None, + help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default') parser.add_argument("--eval_all_checkpoints", action='store_true', help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number") parser.add_argument("--no_cuda", action='store_true',