From b5a6d6eeab2524546776057b40774835c45c4971 Mon Sep 17 00:00:00 2001 From: Jonathan Flynn <91546670+jonflynng@users.noreply.github.com> Date: Tue, 26 Mar 2024 06:13:13 +0000 Subject: [PATCH] Add warnings if training args differ from checkpoint trainer state (#29255) * add warnings if training args differ from checkpoint args stored in trainer_state.json * run formatting and styling * add a test * format and styling --------- Co-authored-by: Jonathan Flynn --- src/transformers/trainer.py | 24 +++++++++++++++++++++ tests/trainer/test_trainer.py | 40 +++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1bf69da039..6adef64757 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1529,6 +1529,29 @@ class Trainer: return model + def compare_trainer_and_checkpoint_args(self, training_args, trainer_state): + attributes_map = { + "logging_steps": "logging_steps", + "eval_steps": "eval_steps", + "save_steps": "save_steps", + "per_device_train_batch_size": "train_batch_size", + } + + warnings_list = [] + for arg_attr, state_attr in attributes_map.items(): + arg_value = getattr(training_args, arg_attr, None) + state_value = getattr(trainer_state, state_attr, None) + + if arg_value is not None and state_value is not None and arg_value != state_value: + warnings_list.append( + f"Warning: The training argument '{arg_attr}' value ({arg_value}) does not match the trainer state '{state_attr}' value ({state_value}). " + f"This argument will be overridden by the one found in trainer_state.json within the checkpoint directory." + ) + + if warnings_list: + for warning in warnings_list: + logger.warning(warning) + def _wrap_model(self, model, training=True, dataloader=None): if self.args.use_ipex: dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 @@ -1991,6 +2014,7 @@ class Trainer: os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) epochs_trained = self.state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ebc628146b..926b7752c0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2485,6 +2485,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5") self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25]) + def test_compare_trainer_and_checkpoint_args_logging(self): + logger = logging.get_logger() + + with tempfile.TemporaryDirectory() as tmpdir, CaptureLogger(logger) as cl: + trainer = get_regression_trainer( + output_dir=tmpdir, + train_len=128, + eval_steps=5, + gradient_accumulation_steps=2, + per_device_train_batch_size=4, + save_steps=5, + learning_rate=0.1, + ) + trainer.train() + + checkpoint = os.path.join(tmpdir, "checkpoint-5") + checkpoint_trainer = get_regression_trainer( + output_dir=tmpdir, + train_len=256, + eval_steps=10, + gradient_accumulation_steps=4, + per_device_train_batch_size=8, + save_steps=10, + learning_rate=0.1, + ) + checkpoint_trainer.train(resume_from_checkpoint=checkpoint) + + self.assertIn( + "Warning: The training argument 'save_steps' value (10) does not match the trainer state 'save_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.", + cl.out, + ) + self.assertIn( + "Warning: The training argument 'per_device_train_batch_size' value (8) does not match the trainer state 'train_batch_size' value (4). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.", + cl.out, + ) + self.assertIn( + "Warning: The training argument 'eval_steps' value (10) does not match the trainer state 'eval_steps' value (5). This argument will be overridden by the one found in trainer_state.json within the checkpoint directory.", + cl.out, + ) + def check_mem_metrics(self, trainer, check_func): metrics = trainer.train().metrics check_func("init_mem_cpu_alloc_delta", metrics)