Add warnings if training args differ from checkpoint trainer state (#29255)
* add warnings if training args differ from checkpoint args stored in trainer_state.json * run formatting and styling * add a test * format and styling --------- Co-authored-by: Jonathan Flynn <jonl.flynn@guardian.co.uk>
This commit is contained in:
@@ -2485,6 +2485,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
|
||||
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
|
||||
|
||||
def test_compare_trainer_and_checkpoint_args_logging(self):
|
||||
logger = logging.get_logger()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir, CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=128,
|
||||
eval_steps=5,
|
||||
gradient_accumulation_steps=2,
|
||||
per_device_train_batch_size=4,
|
||||
save_steps=5,
|
||||
learning_rate=0.1,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
checkpoint_trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=256,
|
||||
eval_steps=10,
|
||||
gradient_accumulation_steps=4,
|
||||
per_device_train_batch_size=8,
|
||||
save_steps=10,
|
||||
learning_rate=0.1,
|
||||
)
|
||||
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
|
||||
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
|
||||
def check_mem_metrics(self, trainer, check_func):
|
||||
metrics = trainer.train().metrics
|
||||
check_func("init_mem_cpu_alloc_delta", metrics)
|
||||
|
||||
Reference in New Issue
Block a user