Rework tests to compare trainer checkpoint args (#29883)

* Start rework

* Fix failing test

* Include max

* Update src/transformers/trainer.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2024-03-30 22:19:17 -04:00
committed by GitHub
parent 6e584070d4
commit 3b8e2932ce
2 changed files with 18 additions and 15 deletions

View File

@@ -1565,23 +1565,28 @@ class Trainer:
"logging_steps": "logging_steps", "logging_steps": "logging_steps",
"eval_steps": "eval_steps", "eval_steps": "eval_steps",
"save_steps": "save_steps", "save_steps": "save_steps",
"per_device_train_batch_size": "train_batch_size",
} }
warnings_list = [] has_warning = False
warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
for arg_attr, state_attr in attributes_map.items(): for arg_attr, state_attr in attributes_map.items():
arg_value = getattr(training_args, arg_attr, None) arg_value = getattr(training_args, arg_attr, None)
state_value = getattr(trainer_state, state_attr, None) state_value = getattr(trainer_state, state_attr, None)
if arg_value is not None and state_value is not None and arg_value != state_value: if arg_value is not None and state_value is not None and arg_value != state_value:
warnings_list.append( warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). " has_warning = True
f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory."
)
if warnings_list: # train bs is special as we need to account for multi-GPU
for warning in warnings_list: train_bs_args = training_args.per_device_train_batch_size
logger.warning(warning) train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
if train_bs_args != train_bs_state:
warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
has_warning = True
if has_warning:
logger.warning_once(warning_str)
def _wrap_model(self, model, training=True, dataloader=None): def _wrap_model(self, model, training=True, dataloader=None):
if self.args.use_ipex: if self.args.use_ipex:

View File

@@ -2540,16 +2540,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
) )
checkpoint_trainer.train(resume_from_checkpoint=checkpoint) checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
self.assertIn("save_steps: 10 (from args) != 5 (from trainer_state.json)", cl.out)
self.assertIn( self.assertIn(
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.", "per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)",
cl.out, cl.out,
) )
self.assertIn( self.assertIn(
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.", "eval_steps: 10 (from args) != 5 (from trainer_state.json)",
cl.out,
)
self.assertIn(
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
cl.out, cl.out,
) )