From 082834dd79d14a2e53332c0c1a7c19852ab8973f Mon Sep 17 00:00:00 2001 From: Manny Cortes Date: Wed, 26 Feb 2025 08:06:48 -0800 Subject: [PATCH] fix: prevent model access error during Optuna hyperparameter tuning (#36395) * fix: prevent model access error during Optuna hyperparameter tuning The `transformers.integrations.integration_utils.run_hp_search_optuna` function releases model memory and sets trainer.model to None after each trial. This causes an AttributeError when subsequent Trainer.train calls attempt to access the model before reinitialization. This is only an issue when `fp16_full_eval` or `bf16_full_eval` flags are enabled. * Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/trainer.py | 7 ++++++- tests/trainer/test_trainer.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 791b698860..b6dffe8d85 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2180,7 +2180,12 @@ class Trainer: # do_train is not a reliable argument, as it might not be set and .train() still called, so # the following is a workaround: - if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel: + if ( + (args.fp16_full_eval or args.bf16_full_eval) + and not args.do_train + and not self.is_model_parallel + and self.model_init is None + ): self._move_model_to_device(self.model, args.device) if "model_path" in kwargs: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5e910bc79b..0a78a524f4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -4998,6 +4998,38 @@ class TrainerHyperParameterMultiObjectOptunaIntegrationTest(unittest.TestCase): ) +@require_torch +@require_optuna +class TrainerHyperParameterOptunaIntegrationTestWithFullEval(unittest.TestCase): + def test_hyperparameter_search(self): + def hp_space(trial): + return {} + + def model_init(trial): + if trial is not None: + a = trial.suggest_int("a", -4, 4) + b = trial.suggest_int("b", -4, 4) + else: + a = 0 + b = 0 + config = RegressionModelConfig(a=a, b=b, double_output=False) + + return RegressionPreTrainedModel(config) + + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = get_regression_trainer( + output_dir=tmp_dir, + disable_tqdm=True, + model_init=model_init, + fp16_full_eval=True, + ) + trainer.hyperparameter_search( + direction="minimize", + hp_space=hp_space, + n_trials=2, + ) + + @require_torch @require_ray class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):