Fix checkpoint deletion (#11748)
This commit is contained in:
@@ -1523,10 +1523,6 @@ class Trainer:
|
|||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
||||||
|
|
||||||
# Maybe delete some older checkpoints.
|
|
||||||
if self.is_world_process_zero():
|
|
||||||
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
|
||||||
|
|
||||||
# Save RNG state in non-distributed training
|
# Save RNG state in non-distributed training
|
||||||
rng_states = {
|
rng_states = {
|
||||||
"python": random.getstate(),
|
"python": random.getstate(),
|
||||||
@@ -1552,6 +1548,10 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
||||||
|
|
||||||
|
# Maybe delete some older checkpoints.
|
||||||
|
if self.is_world_process_zero():
|
||||||
|
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||||
|
|
||||||
def _load_optimizer_and_scheduler(self, checkpoint):
|
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||||
"""If optimizer and scheduler states exist, load them."""
|
"""If optimizer and scheduler states exist, load them."""
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
@@ -1924,7 +1924,7 @@ class Trainer:
|
|||||||
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
||||||
else:
|
else:
|
||||||
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
|
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
|
||||||
if regex_match and regex_match.groups():
|
if regex_match is not None and regex_match.groups() is not None:
|
||||||
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
||||||
|
|
||||||
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
||||||
@@ -1932,10 +1932,8 @@ class Trainer:
|
|||||||
# Make sure we don't delete the best model.
|
# Make sure we don't delete the best model.
|
||||||
if self.state.best_model_checkpoint is not None:
|
if self.state.best_model_checkpoint is not None:
|
||||||
best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
|
best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
|
||||||
checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
|
for i in range(best_model_index, len(checkpoints_sorted) - 2):
|
||||||
checkpoints_sorted[-1],
|
checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
|
||||||
checkpoints_sorted[best_model_index],
|
|
||||||
)
|
|
||||||
return checkpoints_sorted
|
return checkpoints_sorted
|
||||||
|
|
||||||
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
|
def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
|
||||||
@@ -1947,7 +1945,17 @@ class Trainer:
|
|||||||
if len(checkpoints_sorted) <= self.args.save_total_limit:
|
if len(checkpoints_sorted) <= self.args.save_total_limit:
|
||||||
return
|
return
|
||||||
|
|
||||||
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
|
# If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
|
||||||
|
# we don't do to allow resuming.
|
||||||
|
save_total_limit = self.args.save_total_limit
|
||||||
|
if (
|
||||||
|
self.state.best_model_checkpoint is not None
|
||||||
|
and self.args.save_total_limit == 1
|
||||||
|
and checkpoints_sorted[-1] != self.state.best_model_checkpoint
|
||||||
|
):
|
||||||
|
save_total_limit = 2
|
||||||
|
|
||||||
|
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
|
||||||
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
||||||
for checkpoint in checkpoints_to_be_deleted:
|
for checkpoint in checkpoints_to_be_deleted:
|
||||||
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
|
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
from transformers.utils.hp_naming import TrialShortNamer
|
from transformers.utils.hp_naming import TrialShortNamer
|
||||||
|
|
||||||
|
|
||||||
@@ -1048,6 +1050,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
self.assertTrue(isinstance(trainer.state.total_flos, float))
|
self.assertTrue(isinstance(trainer.state.total_flos, float))
|
||||||
|
|
||||||
|
def check_checkpoint_deletion(self, trainer, output_dir, expected):
|
||||||
|
# Make fake checkpoints
|
||||||
|
for n in [5, 10, 15, 20, 25]:
|
||||||
|
os.makedirs(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{n}"), exist_ok=True)
|
||||||
|
trainer._rotate_checkpoints(output_dir=output_dir)
|
||||||
|
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")]
|
||||||
|
values = [int(re.match(f".*{PREFIX_CHECKPOINT_DIR}-([0-9]+)", d).groups()[0]) for d in glob_checkpoints]
|
||||||
|
self.assertSetEqual(set(values), set(expected))
|
||||||
|
|
||||||
|
def test_checkpoint_rotation(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Without best model at end
|
||||||
|
trainer = get_regression_trainer(output_dir=tmp_dir, save_total_limit=2)
|
||||||
|
self.check_checkpoint_deletion(trainer, tmp_dir, [20, 25])
|
||||||
|
|
||||||
|
# With best model at end
|
||||||
|
trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=2)
|
||||||
|
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
|
||||||
|
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
|
||||||
|
|
||||||
|
# Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True to be able to resume
|
||||||
|
# from checkpoint
|
||||||
|
trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=1)
|
||||||
|
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25")
|
||||||
|
self.check_checkpoint_deletion(trainer, tmp_dir, [25])
|
||||||
|
|
||||||
|
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
|
||||||
|
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
|
||||||
|
|
||||||
def check_mem_metrics(self, trainer, check_func):
|
def check_mem_metrics(self, trainer, check_func):
|
||||||
metrics = trainer.train().metrics
|
metrics = trainer.train().metrics
|
||||||
check_func("init_mem_cpu_alloc_delta", metrics)
|
check_func("init_mem_cpu_alloc_delta", metrics)
|
||||||
|
|||||||
Reference in New Issue
Block a user