From c53cc018de70436196858ca91c1a34f1b8947028 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 23 Apr 2020 23:59:43 +0000 Subject: [PATCH] [Trainer] Fix _rotate_checkpoints Close #3920 --- src/transformers/trainer.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 891e70f0eb..6524ba42ab 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -434,13 +434,13 @@ class Trainer: def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: ordering_and_checkpoint_path = [] - glob_checkpoints = Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*") + glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")] for path in glob_checkpoints: if use_mtime: ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) else: - regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path) + regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) if regex_match and regex_match.groups(): ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) @@ -449,9 +449,7 @@ class Trainer: return checkpoints_sorted def _rotate_checkpoints(self, use_mtime=False) -> None: - if not self.args.save_total_limit: - return - if self.args.save_total_limit <= 0: + if self.args.save_total_limit is None or self.args.save_total_limit <= 0: return # Check if we should delete older checkpoint(s)