Enforce saving at end of training if saving option chosen (#30160)
* Enforce saving at end of training * Fix test * Rework test * Fixup tests * Update comment based on sourab feedback * Clean
This commit is contained in:
@@ -129,6 +129,7 @@ if is_torch_available():
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
# for version specific tests in TrainerIntegrationTest
|
||||
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||
@@ -2016,6 +2017,56 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
|
||||
)
|
||||
|
||||
def test_load_best_model_with_save(self):
    """Check that a checkpoint is always written at the final training step.

    Two training runs are exercised, both with `save_steps=5` and an
    evaluation every 5 steps:

    1. `max_steps=9`: the final step is not a multiple of `save_steps`, yet
       `checkpoint-9` must exist once training finishes.
    2. `max_steps=11` with `load_best_model_at_end=True` and
       `save_total_limit=2`: `checkpoint-10` and `checkpoint-11` must be on
       disk, `checkpoint-5` must not, and the weights held by the trainer
       after training must equal those stored in `checkpoint-10`.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            save_steps=5,
            eval_strategy="steps",
            eval_steps=5,
            max_steps=9,
        )
        trainer.train()
        # Check that we have the last known step:
        assert os.path.exists(
            os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
        ), f"Could not find checkpoint-{trainer.state.max_steps}"
        # And then check the last step explicitly (max_steps is 9 here)
        assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9"

    # Now test that using a limit works
    # Should result in:
    # - save at step 5 (but is deleted)
    # - save at step 10 (loaded in at the end when `load_best_model_at_end=True`)
    # - save at step 11
    with tempfile.TemporaryDirectory() as tmpdir:
        trainer = get_regression_trainer(
            output_dir=tmpdir,
            save_steps=5,
            eval_strategy="steps",
            eval_steps=5,
            load_best_model_at_end=True,
            save_total_limit=2,
            max_steps=11,
        )
        trainer.train()
        # Check that we have the last known step:
        assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11"
        # And then check the last multiple
        assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10"
        # Finally check that we don't have an old one
        assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected"

        # Finally check that the right model was loaded in, checkpoint-10
        # this goes by the last `eval` step check to do so, so it won't be
        # the last model *saved*
        model_state = trainer.model.state_dict()
        final_model_weights = safetensors.torch.load_file(
            os.path.join(tmpdir, "checkpoint-10", "model.safetensors")
        )
        for k, v in model_state.items():
            assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same"
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_run_seq2seq_double_train_wrap_once(self):
|
||||
# test that we don't wrap the model more than once
|
||||
|
||||
@@ -153,7 +153,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
||||
expected_events.append("on_log")
|
||||
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
||||
expected_events += evaluation_events.copy()
|
||||
if step % trainer.args.save_steps == 0:
|
||||
if step % trainer.args.save_steps == 0 or step == trainer.state.max_steps:
|
||||
expected_events.append("on_save")
|
||||
expected_events.append("on_epoch_end")
|
||||
if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
|
||||
|
||||
Reference in New Issue
Block a user