Refactor hyperparameter search backends (#24384)

* Refactor hyperparameter search backends

* Simpler refactoring without abstract base class

* black

* review comments:
specify name in class
use methods instead of callable class attributes
name constant better

* review comments: safer bool checking, log multiple available backends

* test ALL_HYPERPARAMETER_SEARCH_BACKENDS vs HPSearchBackend in unit test, not module. format with black.

* copyright
This commit is contained in:
Alex Hall
2023-06-22 20:28:25 +02:00
committed by GitHub
parent a1c4b63076
commit b6295b26c5
6 changed files with 152 additions and 53 deletions

View File

@@ -42,6 +42,7 @@ from transformers import (
is_torch_available,
logging,
)
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
from transformers.testing_utils import (
ENDPOINT_STAGING,
TOKEN,
@@ -72,7 +73,7 @@ from transformers.testing_utils import (
require_wandb,
slow,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
from transformers.training_args import OptimizerNames
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
@@ -2803,3 +2804,11 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
trainer.hyperparameter_search(
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
)
class HyperParameterSearchBackendsTest(unittest.TestCase):
def test_hyperparameter_search_backends(self):
self.assertEqual(
list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
list(HPSearchBackend),
)