@@ -434,13 +434,13 @@ class Trainer:
|
|||||||
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
||||||
ordering_and_checkpoint_path = []
|
ordering_and_checkpoint_path = []
|
||||||
|
|
||||||
glob_checkpoints = Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")
|
glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]
|
||||||
|
|
||||||
for path in glob_checkpoints:
|
for path in glob_checkpoints:
|
||||||
if use_mtime:
|
if use_mtime:
|
||||||
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
||||||
else:
|
else:
|
||||||
regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
|
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
|
||||||
if regex_match and regex_match.groups():
|
if regex_match and regex_match.groups():
|
||||||
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
||||||
|
|
||||||
@@ -449,9 +449,7 @@ class Trainer:
|
|||||||
return checkpoints_sorted
|
return checkpoints_sorted
|
||||||
|
|
||||||
def _rotate_checkpoints(self, use_mtime=False) -> None:
|
def _rotate_checkpoints(self, use_mtime=False) -> None:
|
||||||
if not self.args.save_total_limit:
|
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
|
||||||
return
|
|
||||||
if self.args.save_total_limit <= 0:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if we should delete older checkpoint(s)
|
# Check if we should delete older checkpoint(s)
|
||||||
|
|||||||
Reference in New Issue
Block a user