fix: non-atomic checkpoint save (#27820)
This commit is contained in:
@@ -79,7 +79,8 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -1310,6 +1311,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
||||
|
||||
def test_save_checkpoints_is_atomic(self):
|
||||
class UnsaveableTokenizer(PreTrainedTokenizerBase):
|
||||
def save_pretrained(self, *args, **kwargs):
|
||||
raise OSError("simulated file write error")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
|
||||
# Attach unsaveable tokenizer to partially fail checkpointing
|
||||
trainer.tokenizer = UnsaveableTokenizer()
|
||||
with self.assertRaises(OSError) as _context:
|
||||
trainer.train()
|
||||
assert get_last_checkpoint(tmpdir) is None
|
||||
|
||||
@require_safetensors
|
||||
def test_safe_checkpoints(self):
|
||||
for save_safetensors in [True, False]:
|
||||
|
||||
Reference in New Issue
Block a user