Add warnings if training args differ from checkpoint trainer state (#29255)
* add warnings if training args differ from checkpoint args stored in trainer_state.json * run formatting and styling * add a test * format and styling --------- Co-authored-by: Jonathan Flynn <jonl.flynn@guardian.co.uk>
This commit is contained in:
@@ -1529,6 +1529,29 @@ class Trainer:
|
||||
|
||||
return model
|
||||
|
||||
def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
|
||||
attributes_map = {
|
||||
"logging_steps": "logging_steps",
|
||||
"eval_steps": "eval_steps",
|
||||
"save_steps": "save_steps",
|
||||
"per_device_train_batch_size": "train_batch_size",
|
||||
}
|
||||
|
||||
warnings_list = []
|
||||
for arg_attr, state_attr in attributes_map.items():
|
||||
arg_value = getattr(training_args, arg_attr, None)
|
||||
state_value = getattr(trainer_state, state_attr, None)
|
||||
|
||||
if arg_value is not None and state_value is not None and arg_value != state_value:
|
||||
warnings_list.append(
|
||||
f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). "
|
||||
f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory."
|
||||
)
|
||||
|
||||
if warnings_list:
|
||||
for warning in warnings_list:
|
||||
logger.warning(warning)
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.use_ipex:
|
||||
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
|
||||
@@ -1991,6 +2014,7 @@ class Trainer:
|
||||
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
|
||||
):
|
||||
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
|
||||
self.compare_trainer_and_checkpoint_args(self.args, self.state)
|
||||
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
||||
if not args.ignore_data_skip:
|
||||
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
||||
|
||||
@@ -2485,6 +2485,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
|
||||
self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
|
||||
|
||||
def test_compare_trainer_and_checkpoint_args_logging(self):
|
||||
logger = logging.get_logger()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir, CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=128,
|
||||
eval_steps=5,
|
||||
gradient_accumulation_steps=2,
|
||||
per_device_train_batch_size=4,
|
||||
save_steps=5,
|
||||
learning_rate=0.1,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
checkpoint = os.path.join(tmpdir, "checkpoint-5")
|
||||
checkpoint_trainer = get_regression_trainer(
|
||||
output_dir=tmpdir,
|
||||
train_len=256,
|
||||
eval_steps=10,
|
||||
gradient_accumulation_steps=4,
|
||||
per_device_train_batch_size=8,
|
||||
save_steps=10,
|
||||
learning_rate=0.1,
|
||||
)
|
||||
checkpoint_trainer.train(resume_from_checkpoint=checkpoint)
|
||||
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
self.assertIn(
|
||||
"Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.",
|
||||
cl.out,
|
||||
)
|
||||
|
||||
def check_mem_metrics(self, trainer, check_func):
|
||||
metrics = trainer.train().metrics
|
||||
check_func("init_mem_cpu_alloc_delta", metrics)
|
||||
|
||||
Reference in New Issue
Block a user