Rework tests to compare trainer checkpoint args (#29883)
* Start rework * Fix failing test * Include max * Update src/transformers/trainer.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1565,23 +1565,28 @@ class Trainer:
|
||||
"logging_steps": "logging_steps",
|
||||
"eval_steps": "eval_steps",
|
||||
"save_steps": "save_steps",
|
||||
"per_device_train_batch_size": "train_batch_size",
|
||||
}
|
||||
|
||||
warnings_list = []
|
||||
has_warning = False
|
||||
warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
|
||||
for arg_attr, state_attr in attributes_map.items():
|
||||
arg_value = getattr(training_args, arg_attr, None)
|
||||
state_value = getattr(trainer_state, state_attr, None)
|
||||
|
||||
if arg_value is not None and state_value is not None and arg_value != state_value:
|
||||
warnings_list.append(
|
||||
f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). "
|
||||
f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory."
|
||||
)
|
||||
warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
|
||||
has_warning = True
|
||||
|
||||
if warnings_list:
|
||||
for warning in warnings_list:
|
||||
logger.warning(warning)
|
||||
# train bs is special as we need to account for multi-GPU
|
||||
train_bs_args = training_args.per_device_train_batch_size
|
||||
train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
|
||||
|
||||
if train_bs_args != train_bs_state:
|
||||
warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
|
||||
has_warning = True
|
||||
|
||||
if has_warning:
|
||||
logger.warning_once(warning_str)
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.use_ipex:
|
||||
|
||||
@@ -2540,16 +2540,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
|
||||
|
||||
self.assertIn("save_steps: 10 (from args) != 5 (from trainer_state.json)", cl.out)
|
||||
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
"per_device_train_batch_size: 8 (from args) != 4 (from trainer_state.json)",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
"eval_steps: 10 (from args) != 5 (from trainer_state.json)",
|
||||
cl.out,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user