fix wandb hp search unable to resume from sweep_id (#35883)

* fix wandb hp search unable to resume from sweep_id

* format styles

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
bd793fcb
2025-03-14 00:32:26 +13:00
committed by GitHub
parent 47cc4da351
commit 87b30c3589
2 changed files with 35 additions and 11 deletions

View File

@@ -5707,9 +5707,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
self.batch_size = args.train_batch_size
def test_hyperparameter_search(self):
class MyTrialShortNamer(TrialShortNamer):
DEFAULTS = {"a": 0, "b": 0}
def hp_space(trial):
return {
"method": "random",
@@ -5731,9 +5728,6 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
return RegressionPreTrainedModel(model_config)
def hp_name(params):
return MyTrialShortNamer.shortname(params)
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=tmp_dir,
@@ -5748,9 +5742,31 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
run_name="test",
model_init=model_init,
)
trainer.hyperparameter_search(
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
)
sweep_kwargs = {
"direction": "minimize",
"hp_space": hp_space,
"backend": "wandb",
"n_trials": 4,
}
best_run = trainer.hyperparameter_search(**sweep_kwargs)
self.assertIsNotNone(best_run.run_id)
self.assertIsNotNone(best_run.run_summary)
hp_keys = set(best_run.hyperparameters.keys())
self.assertSetEqual(hp_keys, {"a", "b", "assignments", "metric"})
# pretend restarting the process purged the environ
import os
del os.environ["WANDB_ENTITY"]
del os.environ["WANDB_PROJECT"]
sweep_kwargs["sweep_id"] = best_run.run_summary
updated_best_run = trainer.hyperparameter_search(**sweep_kwargs)
self.assertIsNotNone(updated_best_run.run_id)
self.assertEqual(updated_best_run.run_summary, best_run.run_summary)
updated_hp_keys = set(updated_best_run.hyperparameters.keys())
self.assertSetEqual(updated_hp_keys, {"a", "b", "assignments", "metric"})
class HyperParameterSearchBackendsTest(unittest.TestCase):