Extract method
This commit is contained in:
committed by
Lysandre Debut
parent
d6c5469712
commit
56301bd9e8
@@ -106,6 +106,26 @@ def set_seed(args):
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def rotate_checkpoints(args):
|
||||
if args.save_total_limit and args.save_total_limit > 0:
|
||||
# Check if we should delete older checkpoint(s)
|
||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
|
||||
if len(glob_checkpoints) > args.save_total_limit:
|
||||
checkpoints_sorted = []
|
||||
for path in glob_checkpoints:
|
||||
regex_match = re.match('.*checkpoint-([0-9]+)', path)
|
||||
if regex_match and regex_match.groups():
|
||||
checkpoints_sorted.append((int(regex_match.groups()[0]), path))
|
||||
|
||||
checkpoints_sorted = sorted(checkpoints_sorted)
|
||||
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
||||
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
|
||||
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
||||
for checkpoint in checkpoints_to_be_deleted:
|
||||
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
||||
shutil.rmtree(checkpoint)
|
||||
|
||||
|
||||
def mask_tokens(inputs, tokenizer, args):
|
||||
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
||||
labels = inputs.clone()
|
||||
@@ -233,23 +253,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||
logger.info("Saving model checkpoint to %s", output_dir)
|
||||
|
||||
if args.save_total_limit and args.save_total_limit > 0:
|
||||
# Check if we should delete older checkpoint(s)
|
||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
|
||||
if len(glob_checkpoints) > args.save_total_limit:
|
||||
checkpoints_sorted = []
|
||||
for path in glob_checkpoints:
|
||||
regex_match = re.match('.*checkpoint-([0-9]+)', path)
|
||||
if regex_match and regex_match.groups():
|
||||
checkpoints_sorted.append((int(regex_match.groups()[0]), path))
|
||||
|
||||
checkpoints_sorted = sorted(checkpoints_sorted)
|
||||
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
||||
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
|
||||
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
||||
for checkpoint in checkpoints_to_be_deleted:
|
||||
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
||||
shutil.rmtree(checkpoint)
|
||||
rotate_checkpoints(args)
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
epoch_iterator.close()
|
||||
|
||||
Reference in New Issue
Block a user