Rework tests to compare trainer checkpoint args (#29883)
* Start rework * Fix failing test * Include max * Update src/transformers/trainer.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -2540,16 +2540,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
|
||||
|
||||
self.assertIn("save_steps: 10 (from args) != 5 (from trainer_state.json)", cl.out)
|
||||
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
"per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
"eval_steps: 10 (from args) != 5 (from trainer_state.json)",
|
||||
cl.out,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user