From 4c5ed1d0c942dd4a60e1f99d6636519c40f7c904 Mon Sep 17 00:00:00 2001
From: Jonathon Belotti
Date: Fri, 8 Dec 2023 08:08:54 -0500
Subject: [PATCH] fix: non-atomic checkpoint save (#27820)

Write each checkpoint into a `tmp-<checkpoint_folder>` staging directory
and rename it to its final name only after every artifact has been saved,
so that a save which fails partway is not picked up by
`get_last_checkpoint`. If the destination directory already exists and is
non-empty, a warning is logged and the checkpoint is written in place.

The rename is skipped when the staging directory does not exist on the
current process (nothing was written there, or it was already moved).
---
 src/transformers/trainer.py   | 23 ++++++++++++++++++-----
 tests/trainer/test_trainer.py | 17 ++++++++++++++++-
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 422be2247b..009e24ade0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2332,13 +2332,21 @@ class Trainer:
 
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
+        if os.path.exists(output_dir) and os.listdir(output_dir):
+            logger.warning(
+                f"Checkpoint destination directory {output_dir} already exists and is non-empty. "
+                "Saving will proceed but saved results may be invalid."
+            )
+            staging_output_dir = output_dir
+        else:
+            staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
+        self.save_model(staging_output_dir, _internal_call=True)
 
         if not self.args.save_only_model:
             # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
+            self._save_optimizer_and_scheduler(staging_output_dir)
             # Save RNG state
-            self._save_rng_state(output_dir)
+            self._save_rng_state(staging_output_dir)
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2358,10 +2366,15 @@ class Trainer:
 
         # Save the Trainer state
         if self.args.should_save:
-            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME))
 
         if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
+            self._push_from_checkpoint(staging_output_dir)
+
+        # Place checkpoint in final location after all saving is finished.
+        # Skip when this process has no staging directory to move (nothing written, or already moved).
+        if staging_output_dir != output_dir and os.path.exists(staging_output_dir):
+            os.rename(staging_output_dir, output_dir)
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 305ccb35d5..129f40fc40 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -79,7 +79,8 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -1310,6 +1311,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
 
+    def test_save_checkpoints_is_atomic(self):
+        """A save that fails partway must not leave a checkpoint that `get_last_checkpoint` finds."""
+        class UnsaveableTokenizer(PreTrainedTokenizerBase):
+            def save_pretrained(self, *args, **kwargs):
+                raise OSError("simulated file write error")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
+            # Attach unsaveable tokenizer to partially fail checkpointing
+            trainer.tokenizer = UnsaveableTokenizer()
+            with self.assertRaises(OSError):
+                trainer.train()
+            self.assertIsNone(get_last_checkpoint(tmpdir))
+
     @require_safetensors
     def test_safe_checkpoints(self):
         for save_safetensors in [True, False]: