Improve readability and improve make less assumptions about checkpoint format
This commit is contained in:
committed by
Lysandre Debut
parent
56301bd9e8
commit
528d3f327b
@@ -106,15 +106,22 @@ def set_seed(args):
|
|||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
def rotate_checkpoints(args):
|
def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
||||||
if args.save_total_limit and args.save_total_limit > 0:
|
if not args.save_total_limit:
|
||||||
|
return
|
||||||
|
if args.save_total_limit <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
# Check if we should delete older checkpoint(s)
|
# Check if we should delete older checkpoint(s)
|
||||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
|
glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
|
||||||
if len(glob_checkpoints) > args.save_total_limit:
|
if len(glob_checkpoints) > args.save_total_limit:
|
||||||
checkpoints_sorted = []
|
checkpoints_sorted = []
|
||||||
for path in glob_checkpoints:
|
for path in glob_checkpoints:
|
||||||
regex_match = re.match('.*checkpoint-([0-9]+)', path)
|
regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
|
||||||
if regex_match and regex_match.groups():
|
if regex_match and regex_match.groups():
|
||||||
|
if use_mtime:
|
||||||
|
checkpoints_sorted.append((os.path.getmtime(path), path))
|
||||||
|
else:
|
||||||
checkpoints_sorted.append((int(regex_match.groups()[0]), path))
|
checkpoints_sorted.append((int(regex_match.groups()[0]), path))
|
||||||
|
|
||||||
checkpoints_sorted = sorted(checkpoints_sorted)
|
checkpoints_sorted = sorted(checkpoints_sorted)
|
||||||
@@ -244,8 +251,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
logging_loss = tr_loss
|
logging_loss = tr_loss
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||||
|
checkpoint_prefix = 'checkpoint'
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
|
output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||||
@@ -253,7 +261,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
rotate_checkpoints(args)
|
_rotate_checkpoints(args, checkpoint_prefix)
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
epoch_iterator.close()
|
epoch_iterator.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user