Add warnings if training args differ from checkpoint trainer state (#29255)

* add warnings if training args differ from checkpoint args stored in trainer_state.json

* run formatting and styling

* add a test

* format and styling

---------

Co-authored-by: Jonathan Flynn <jonl.flynn@guardian.co.uk>
This commit is contained in:
Jonathan Flynn
2024-03-26 06:13:13 +00:00
committed by GitHub
parent 7eb3ba8224
commit b5a6d6eeab
2 changed files with 64 additions and 0 deletions

View File

@@ -2485,6 +2485,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
def test_compare_trainer_and_checkpoint_args_logging(self):
logger = logging.get_logger()
with tempfile.TemporaryDirectory() as tmpdir, CaptureLogger(logger) as cl:
trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=128,
eval_steps=5,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
save_steps=5,
learning_rate=0.1,
)
trainer.train()
checkpoint = os.path.join(tmpdir, "checkpoint-5")
checkpoint_trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=256,
eval_steps=10,
gradient_accumulation_steps=4,
per_device_train_batch_size=8,
save_steps=10,
learning_rate=0.1,
)
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
self.assertIn(
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
self.assertIn(
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
self.assertIn(
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
def check_mem_metrics(self, trainer, check_func):
metrics = trainer.train().metrics
check_func("init_mem_cpu_alloc_delta", metrics)