fix wandb hp search unable to resume from sweep_id (#35883)
* fix wandb hp search unable to resume from sweep_id * format styles --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -585,11 +585,19 @@ def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> Bes
|
|||||||
|
|
||||||
return trainer.objective
|
return trainer.objective
|
||||||
|
|
||||||
sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
|
if not sweep_id:
|
||||||
|
sweep_id = wandb.sweep(sweep_config, project=project, entity=entity)
|
||||||
|
else:
|
||||||
|
import wandb.env
|
||||||
|
|
||||||
|
if entity:
|
||||||
|
wandb.env.set_entity(entity)
|
||||||
|
wandb.env.set_project(project)
|
||||||
|
|
||||||
logger.info(f"wandb sweep id - {sweep_id}")
|
logger.info(f"wandb sweep id - {sweep_id}")
|
||||||
wandb.agent(sweep_id, function=_objective, count=n_trials)
|
wandb.agent(sweep_id, function=_objective, count=n_trials)
|
||||||
|
|
||||||
return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
|
return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"], sweep_id)
|
||||||
|
|
||||||
|
|
||||||
def get_available_reporting_integrations():
|
def get_available_reporting_integrations():
|
||||||
|
|||||||
@@ -5707,9 +5707,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
|||||||
self.batch_size = args.train_batch_size
|
self.batch_size = args.train_batch_size
|
||||||
|
|
||||||
def test_hyperparameter_search(self):
|
def test_hyperparameter_search(self):
|
||||||
class MyTrialShortNamer(TrialShortNamer):
|
|
||||||
DEFAULTS = {"a": 0, "b": 0}
|
|
||||||
|
|
||||||
def hp_space(trial):
|
def hp_space(trial):
|
||||||
return {
|
return {
|
||||||
"method": "random",
|
"method": "random",
|
||||||
@@ -5731,9 +5728,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
return RegressionPreTrainedModel(model_config)
|
return RegressionPreTrainedModel(model_config)
|
||||||
|
|
||||||
def hp_name(params):
|
|
||||||
return MyTrialShortNamer.shortname(params)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
trainer = get_regression_trainer(
|
trainer = get_regression_trainer(
|
||||||
output_dir=tmp_dir,
|
output_dir=tmp_dir,
|
||||||
@@ -5748,9 +5742,31 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
|||||||
run_name="test",
|
run_name="test",
|
||||||
model_init=model_init,
|
model_init=model_init,
|
||||||
)
|
)
|
||||||
trainer.hyperparameter_search(
|
sweep_kwargs = {
|
||||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
|
"direction": "minimize",
|
||||||
)
|
"hp_space": hp_space,
|
||||||
|
"backend": "wandb",
|
||||||
|
"n_trials": 4,
|
||||||
|
}
|
||||||
|
best_run = trainer.hyperparameter_search(**sweep_kwargs)
|
||||||
|
|
||||||
|
self.assertIsNotNone(best_run.run_id)
|
||||||
|
self.assertIsNotNone(best_run.run_summary)
|
||||||
|
hp_keys = set(best_run.hyperparameters.keys())
|
||||||
|
self.assertSetEqual(hp_keys, {"a", "b", "assignments", "metric"})
|
||||||
|
|
||||||
|
# pretend restarting the process purged the environ
|
||||||
|
import os
|
||||||
|
|
||||||
|
del os.environ["WANDB_ENTITY"]
|
||||||
|
del os.environ["WANDB_PROJECT"]
|
||||||
|
sweep_kwargs["sweep_id"] = best_run.run_summary
|
||||||
|
updated_best_run = trainer.hyperparameter_search(**sweep_kwargs)
|
||||||
|
|
||||||
|
self.assertIsNotNone(updated_best_run.run_id)
|
||||||
|
self.assertEqual(updated_best_run.run_summary, best_run.run_summary)
|
||||||
|
updated_hp_keys = set(updated_best_run.hyperparameters.keys())
|
||||||
|
self.assertSetEqual(updated_hp_keys, {"a", "b", "assignments", "metric"})
|
||||||
|
|
||||||
|
|
||||||
class HyperParameterSearchBackendsTest(unittest.TestCase):
|
class HyperParameterSearchBackendsTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user