Enforce saving at end of training if saving option chosen (#30160)
* Enforce saving at end of training * Fix test * Rework test * Fixup tests' * Update comment based on sourab feedback * Clean
This commit is contained in:
@@ -544,6 +544,9 @@ class DefaultFlowCallback(TrainerCallback):
|
|||||||
# End training
|
# End training
|
||||||
if state.global_step >= state.max_steps:
|
if state.global_step >= state.max_steps:
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
|
# Save the model at the end if we have a save strategy
|
||||||
|
if args.save_strategy != IntervalStrategy.NO:
|
||||||
|
control.should_save = True
|
||||||
|
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|||||||
@@ -335,6 +335,9 @@ class TrainingArguments:
|
|||||||
- `"no"`: No save is done during training.
|
- `"no"`: No save is done during training.
|
||||||
- `"epoch"`: Save is done at the end of each epoch.
|
- `"epoch"`: Save is done at the end of each epoch.
|
||||||
- `"steps"`: Save is done every `save_steps`.
|
- `"steps"`: Save is done every `save_steps`.
|
||||||
|
|
||||||
|
If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
|
||||||
|
very end of training, always.
|
||||||
save_steps (`int` or `float`, *optional*, defaults to 500):
|
save_steps (`int` or `float`, *optional*, defaults to 500):
|
||||||
Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
|
Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
|
||||||
float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
|
float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ if is_torch_available():
|
|||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
|
|
||||||
# for version specific tests in TrainerIntegrationTest
|
# for version specific tests in TrainerIntegrationTest
|
||||||
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||||
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||||
@@ -2016,6 +2017,56 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
|
tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_load_best_model_with_save(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=tmpdir,
|
||||||
|
save_steps=5,
|
||||||
|
evaluation_strategy="steps",
|
||||||
|
eval_steps=5,
|
||||||
|
max_steps=9,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
# Check that we have the last known step:
|
||||||
|
assert os.path.exists(
|
||||||
|
os.path.join(tmpdir, f"checkpoint-{trainer.state.max_steps}")
|
||||||
|
), f"Could not find checkpoint-{trainer.state.max_steps}"
|
||||||
|
# And then check the last step
|
||||||
|
assert os.path.exists(os.path.join(tmpdir, "checkpoint-9")), "Could not find checkpoint-9"
|
||||||
|
|
||||||
|
# Now test that using a limit works
|
||||||
|
# Should result in:
|
||||||
|
# - save at step 5 (but is deleted)
|
||||||
|
# - save at step 10 (loaded in at the end when `load_best_model=True`)
|
||||||
|
# - save at step 11
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=tmpdir,
|
||||||
|
save_steps=5,
|
||||||
|
evaluation_strategy="steps",
|
||||||
|
eval_steps=5,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
save_total_limit=2,
|
||||||
|
max_steps=11,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
# Check that we have the last known step:
|
||||||
|
assert os.path.exists(os.path.join(tmpdir, "checkpoint-11")), "Could not find checkpoint-11"
|
||||||
|
# And then check the last multiple
|
||||||
|
assert os.path.exists(os.path.join(tmpdir, "checkpoint-10")), "Could not find checkpoint-10"
|
||||||
|
# Finally check that we don't have an old one
|
||||||
|
assert not os.path.exists(os.path.join(tmpdir, "checkpoint-5")), "Found checkpoint-5, limit not respected"
|
||||||
|
|
||||||
|
# Finally check that the right model was loaded in, checkpoint-10
|
||||||
|
# this goes by the last `eval` step check to do so, so it won't be
|
||||||
|
# the last model *saved*
|
||||||
|
model_state = trainer.model.state_dict()
|
||||||
|
final_model_weights = safetensors.torch.load_file(
|
||||||
|
os.path.join(tmpdir, "checkpoint-10", "model.safetensors")
|
||||||
|
)
|
||||||
|
for k, v in model_state.items():
|
||||||
|
assert torch.allclose(v, final_model_weights[k]), f"{k} is not the same"
|
||||||
|
|
||||||
@require_torch_multi_accelerator
|
@require_torch_multi_accelerator
|
||||||
def test_run_seq2seq_double_train_wrap_once(self):
|
def test_run_seq2seq_double_train_wrap_once(self):
|
||||||
# test that we don't wrap the model more than once
|
# test that we don't wrap the model more than once
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
expected_events.append("on_log")
|
expected_events.append("on_log")
|
||||||
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
||||||
expected_events += evaluation_events.copy()
|
expected_events += evaluation_events.copy()
|
||||||
if step % trainer.args.save_steps == 0:
|
if step % trainer.args.save_steps == 0 or step == trainer.state.max_steps:
|
||||||
expected_events.append("on_save")
|
expected_events.append("on_save")
|
||||||
expected_events.append("on_epoch_end")
|
expected_events.append("on_epoch_end")
|
||||||
if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
|
if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
|
||||||
|
|||||||
Reference in New Issue
Block a user