fix wandb hp search unable to resume from sweep_id (#35883)

* fix wandb hp search unable to resume from sweep_id

* format styles

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
bd793fcb
2025-03-14 00:32:26 +13:00
committed by GitHub
parent 47cc4da351
commit 87b30c3589
2 changed files with 35 additions and 11 deletions

View File

@@ -585,11 +585,19 @@ def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> Bes
return trainer.objective
sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
if not sweep_id:
sweep_id = wandb.sweep(sweep_config, project=project, entity=entity)
else:
import wandb.env
if entity:
wandb.env.set_entity(entity)
wandb.env.set_project(project)
logger.info(f"wandb sweep id - {sweep_id}")
wandb.agent(sweep_id, function=_objective, count=n_trials)
return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"], sweep_id)
def get_available_reporting_integrations():

View File

@@ -5707,9 +5707,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
self.batch_size = args.train_batch_size
def test_hyperparameter_search(self):
class MyTrialShortNamer(TrialShortNamer):
DEFAULTS = {"a": 0, "b": 0}
def hp_space(trial):
return {
"method": "random",
@@ -5731,9 +5728,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
return RegressionPreTrainedModel(model_config)
def hp_name(params):
return MyTrialShortNamer.shortname(params)
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=tmp_dir,
@@ -5748,9 +5742,31 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
run_name="test",
model_init=model_init,
)
trainer.hyperparameter_search(
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
)
sweep_kwargs = {
"direction": "minimize",
"hp_space": hp_space,
"backend": "wandb",
"n_trials": 4,
}
best_run = trainer.hyperparameter_search(**sweep_kwargs)
self.assertIsNotNone(best_run.run_id)
self.assertIsNotNone(best_run.run_summary)
hp_keys = set(best_run.hyperparameters.keys())
self.assertSetEqual(hp_keys, {"a", "b", "assignments", "metric"})
# pretend restarting the process purged the environ
import os
del os.environ["WANDB_ENTITY"]
del os.environ["WANDB_PROJECT"]
sweep_kwargs["sweep_id"] = best_run.run_summary
updated_best_run = trainer.hyperparameter_search(**sweep_kwargs)
self.assertIsNotNone(updated_best_run.run_id)
self.assertEqual(updated_best_run.run_summary, best_run.run_summary)
updated_hp_keys = set(updated_best_run.hyperparameters.keys())
self.assertSetEqual(updated_hp_keys, {"a", "b", "assignments", "metric"})
class HyperParameterSearchBackendsTest(unittest.TestCase):