Rework tests to compare trainer checkpoint args (#29883)

* Start rework

* Fix failing test

* Include max

* Update src/transformers/trainer.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2024-03-30 22:19:17 -04:00
committed by GitHub
parent 6e584070d4
commit 3b8e2932ce
2 changed files with 18 additions and 15 deletions

View File

@@ -2540,16 +2540,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
)
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
self.assertIn("save_steps: 10 (from args) != 5 (from trainer_state.json)", cl.out)
self.assertIn(
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
"per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)",
cl.out,
)
self.assertIn(
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out,
)
self.assertIn(
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
"eval_steps: 10 (from args) != 5 (from trainer_state.json)",
cl.out,
)