Refactor hyperparameter search backends (#24384)
* Refactor hyperparameter search backends * Simpler refactoring without abstract base class * black * review comments: specify name in class use methods instead of callable class attributes name constant better * review comments: safer bool checking, log multiple available backends * test ALL_HYPERPARAMETER_SEARCH_BACKENDS vs HPSearchBackend in unit test, not module. format with black. * copyright
This commit is contained in:
@@ -42,6 +42,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
TOKEN,
|
||||
@@ -72,7 +73,7 @@ from transformers.testing_utils import (
|
||||
require_wandb,
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -2803,3 +2804,11 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
||||
trainer.hyperparameter_search(
|
||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
|
||||
)
|
||||
|
||||
|
||||
class HyperParameterSearchBackendsTest(unittest.TestCase):
|
||||
def test_hyperparameter_search_backends(self):
|
||||
self.assertEqual(
|
||||
list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
|
||||
list(HPSearchBackend),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user