Add save_total_limit
This commit is contained in:
committed by
Lysandre Debut
parent
1c5079952f
commit
54a31f50fb
@@ -27,6 +27,8 @@ import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -222,6 +224,24 @@ def train(args, train_dataset, model, tokenizer):
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
if args.save_total_limit and args.save_total_limit > 0:
|
||||
# Check if we should delete older checkpoint(s)
|
||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
|
||||
if len(glob_checkpoints) + 1 > args.save_total_limit:
|
||||
checkpoints_sorted = []
|
||||
for path in glob_checkpoints:
|
||||
regex_match = re.match('.*checkpoint-([0-9]+)', path)
|
||||
if regex_match and regex_match.groups():
|
||||
checkpoints_sorted.append((int(regex_match.groups()[0]), path))
|
||||
|
||||
checkpoints_sorted = sorted(checkpoints_sorted)
|
||||
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
||||
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) + 1 - args.save_total_limit)
|
||||
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
||||
for checkpoint in checkpoints_to_be_deleted:
|
||||
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
||||
shutil.rmtree(checkpoint)
|
||||
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
||||
if not os.path.exists(output_dir):
|
||||
@@ -359,6 +379,8 @@ def main():
|
||||
help="Log every X updates steps.")
|
||||
parser.add_argument('--save_steps', type=int, default=50,
|
||||
help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument('--save_total_limit', type=int, default=None,
|
||||
help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default')
|
||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
|
||||
Reference in New Issue
Block a user