fix wandb hp search unable to resume from sweep_id (#35883)
* fix wandb hp search unable to resume from sweep_id * format styles --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -585,11 +585,19 @@ def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> Bes
|
||||
|
||||
return trainer.objective
|
||||
|
||||
sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
|
||||
if not sweep_id:
|
||||
sweep_id = wandb.sweep(sweep_config, project=project, entity=entity)
|
||||
else:
|
||||
import wandb.env
|
||||
|
||||
if entity:
|
||||
wandb.env.set_entity(entity)
|
||||
wandb.env.set_project(project)
|
||||
|
||||
logger.info(f"wandb sweep id - {sweep_id}")
|
||||
wandb.agent(sweep_id, function=_objective, count=n_trials)
|
||||
|
||||
return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
|
||||
return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"], sweep_id)
|
||||
|
||||
|
||||
def get_available_reporting_integrations():
|
||||
|
||||
@@ -5707,9 +5707,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
||||
self.batch_size = args.train_batch_size
|
||||
|
||||
def test_hyperparameter_search(self):
|
||||
class MyTrialShortNamer(TrialShortNamer):
|
||||
DEFAULTS = {"a": 0, "b": 0}
|
||||
|
||||
def hp_space(trial):
|
||||
return {
|
||||
"method": "random",
|
||||
@@ -5731,9 +5728,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
||||
|
||||
return RegressionPreTrainedModel(model_config)
|
||||
|
||||
def hp_name(params):
|
||||
return MyTrialShortNamer.shortname(params)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=tmp_dir,
|
||||
@@ -5748,9 +5742,31 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
||||
run_name="test",
|
||||
model_init=model_init,
|
||||
)
|
||||
trainer.hyperparameter_search(
|
||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
|
||||
)
|
||||
sweep_kwargs = {
|
||||
"direction": "minimize",
|
||||
"hp_space": hp_space,
|
||||
"backend": "wandb",
|
||||
"n_trials": 4,
|
||||
}
|
||||
best_run = trainer.hyperparameter_search(**sweep_kwargs)
|
||||
|
||||
self.assertIsNotNone(best_run.run_id)
|
||||
self.assertIsNotNone(best_run.run_summary)
|
||||
hp_keys = set(best_run.hyperparameters.keys())
|
||||
self.assertSetEqual(hp_keys, {"a", "b", "assignments", "metric"})
|
||||
|
||||
# pretend restarting the process purged the environ
|
||||
import os
|
||||
|
||||
del os.environ["WANDB_ENTITY"]
|
||||
del os.environ["WANDB_PROJECT"]
|
||||
sweep_kwargs["sweep_id"] = best_run.run_summary
|
||||
updated_best_run = trainer.hyperparameter_search(**sweep_kwargs)
|
||||
|
||||
self.assertIsNotNone(updated_best_run.run_id)
|
||||
self.assertEqual(updated_best_run.run_summary, best_run.run_summary)
|
||||
updated_hp_keys = set(updated_best_run.hyperparameters.keys())
|
||||
self.assertSetEqual(updated_hp_keys, {"a", "b", "assignments", "metric"})
|
||||
|
||||
|
||||
class HyperParameterSearchBackendsTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user