Fix checkpoint deletion (#11748)
This commit is contained in:
@@ -21,6 +21,7 @@ import random
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -45,6 +46,7 @@ from transformers.testing_utils import (
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
|
||||
|
||||
@@ -1048,6 +1050,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
self.assertTrue(isinstance(trainer.state.total_flos, float))
|
||||
|
||||
def check_checkpoint_deletion(self, trainer, output_dir, expected):
|
||||
# Make fake checkpoints
|
||||
for n in [5, 10, 15, 20, 25]:
|
||||
os.makedirs(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{n}"), exist_ok=True)
|
||||
trainer._rotate_checkpoints(output_dir=output_dir)
|
||||
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")]
|
||||
values = [int(re.match(f".*{PREFIX_CHECKPOINT_DIR}-([0-9]+)", d).groups()[0]) for d in glob_checkpoints]
|
||||
self.assertSetEqual(set(values), set(expected))
|
||||
|
||||
def test_checkpoint_rotation(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Without best model at end
|
||||
trainer = get_regression_trainer(output_dir=tmp_dir, save_total_limit=2)
|
||||
self.check_checkpoint_deletion(trainer, tmp_dir, [20, 25])
|
||||
|
||||
# With best model at end
|
||||
trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=2)
|
||||
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
|
||||
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
|
||||
|
||||
# Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True to be able to resume
|
||||
# from checkpoint
|
||||
trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=1)
|
||||
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25")
|
||||
self.check_checkpoint_deletion(trainer, tmp_dir, [25])
|
||||
|
||||
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
|
||||
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
|
||||
|
||||
def check_mem_metrics(self, trainer, check_func):
|
||||
metrics = trainer.train().metrics
|
||||
check_func("init_mem_cpu_alloc_delta", metrics)
|
||||
|
||||
Reference in New Issue
Block a user