From 87b30c35892568f9b83d4e8d1233956b8e0cd96c Mon Sep 17 00:00:00 2001 From: bd793fcb <79692219+bd793fcb@users.noreply.github.com> Date: Fri, 14 Mar 2025 00:32:26 +1300 Subject: [PATCH] fix wandb hp search unable to resume from sweep_id (#35883) * fix wandb hp search unable to resume from sweep_id * format styles --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- .../integrations/integration_utils.py | 12 +++++-- tests/trainer/test_trainer.py | 34 ++++++++++++++----- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index d5b6983c30..95668fefd7 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -585,11 +585,19 @@ def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> Bes return trainer.objective - sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id + if not sweep_id: + sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) + else: + import wandb.env + + if entity: + wandb.env.set_entity(entity) + wandb.env.set_project(project) + logger.info(f"wandb sweep id - {sweep_id}") wandb.agent(sweep_id, function=_objective, count=n_trials) - return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"]) + return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"], sweep_id) def get_available_reporting_integrations(): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1f6441758c..c8d9f34ff5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5707,9 +5707,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase): self.batch_size = args.train_batch_size def test_hyperparameter_search(self): - class MyTrialShortNamer(TrialShortNamer): - DEFAULTS = {"a": 0, "b": 0} - def hp_space(trial): return { "method": "random", @@ -5731,9 +5728,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase): return RegressionPreTrainedModel(model_config) - def hp_name(params): - return MyTrialShortNamer.shortname(params) - with tempfile.TemporaryDirectory() as tmp_dir: trainer = get_regression_trainer( output_dir=tmp_dir, @@ -5748,9 +5742,31 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase): run_name="test", model_init=model_init, ) - trainer.hyperparameter_search( - direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must" - ) + sweep_kwargs = { + "direction": "minimize", + "hp_space": hp_space, + "backend": "wandb", + "n_trials": 4, + } + best_run = trainer.hyperparameter_search(**sweep_kwargs) + + self.assertIsNotNone(best_run.run_id) + self.assertIsNotNone(best_run.run_summary) + hp_keys = set(best_run.hyperparameters.keys()) + self.assertSetEqual(hp_keys, {"a", "b", "assignments", "metric"}) + + # pretend restarting the process purged the environ + import os + + del os.environ["WANDB_ENTITY"] + del os.environ["WANDB_PROJECT"] + sweep_kwargs["sweep_id"] = best_run.run_summary + updated_best_run = trainer.hyperparameter_search(**sweep_kwargs) + + self.assertIsNotNone(updated_best_run.run_id) + self.assertEqual(updated_best_run.run_summary, best_run.run_summary) + updated_hp_keys = set(updated_best_run.hyperparameters.keys()) + self.assertSetEqual(updated_hp_keys, {"a", "b", "assignments", "metric"}) class HyperParameterSearchBackendsTest(unittest.TestCase):