Refactor hyperparameter search backends (#24384)
* Refactor hyperparameter search backends * Simpler refactoring without abstract base class * black * review comments: specify name in class use methods instead of callable class attributes name constant better * review comments: safer bool checking, log multiple available backends * test ALL_HYPERPARAMETER_SEARCH_BACKENDS vs HPSearchBackend in unit test, not module. format with black. * copyright
This commit is contained in:
@@ -98,6 +98,7 @@ _import_structure = {
|
||||
"file_utils": [],
|
||||
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
|
||||
"hf_argparser": ["HfArgumentParser"],
|
||||
"hyperparameter_search": [],
|
||||
"image_transforms": [],
|
||||
"integrations": [
|
||||
"is_clearml_available",
|
||||
|
||||
136
src/transformers/hyperparameter_search.py
Normal file
136
src/transformers/hyperparameter_search.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .integrations import (
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_sigopt_available,
|
||||
is_wandb_available,
|
||||
run_hp_search_optuna,
|
||||
run_hp_search_ray,
|
||||
run_hp_search_sigopt,
|
||||
run_hp_search_wandb,
|
||||
)
|
||||
from .trainer_utils import (
|
||||
HPSearchBackend,
|
||||
default_hp_space_optuna,
|
||||
default_hp_space_ray,
|
||||
default_hp_space_sigopt,
|
||||
default_hp_space_wandb,
|
||||
)
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class HyperParamSearchBackendBase:
|
||||
name: str
|
||||
pip_package: str = None
|
||||
|
||||
def is_available(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def default_hp_space(self, trial):
|
||||
raise NotImplementedError
|
||||
|
||||
def ensure_available(self):
|
||||
if not self.is_available():
|
||||
raise RuntimeError(
|
||||
f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pip_install(cls):
|
||||
return f"`pip install {cls.pip_package or cls.name}`"
|
||||
|
||||
|
||||
class OptunaBackend(HyperParamSearchBackendBase):
|
||||
name = "optuna"
|
||||
|
||||
def is_available(self):
|
||||
return is_optuna_available()
|
||||
|
||||
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||
return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
|
||||
|
||||
def default_hp_space(self, trial):
|
||||
return default_hp_space_optuna(trial)
|
||||
|
||||
|
||||
class RayTuneBackend(HyperParamSearchBackendBase):
|
||||
name = "ray"
|
||||
pip_package = "'ray[tune]'"
|
||||
|
||||
def is_available(self):
|
||||
return is_ray_available()
|
||||
|
||||
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
|
||||
|
||||
def default_hp_space(self, trial):
|
||||
return default_hp_space_ray(trial)
|
||||
|
||||
|
||||
class SigOptBackend(HyperParamSearchBackendBase):
|
||||
name = "sigopt"
|
||||
|
||||
def is_available(self):
|
||||
return is_sigopt_available()
|
||||
|
||||
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||
return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
|
||||
|
||||
def default_hp_space(self, trial):
|
||||
return default_hp_space_sigopt(trial)
|
||||
|
||||
|
||||
class WandbBackend(HyperParamSearchBackendBase):
|
||||
name = "wandb"
|
||||
|
||||
def is_available(self):
|
||||
return is_wandb_available()
|
||||
|
||||
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||
return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
|
||||
|
||||
def default_hp_space(self, trial):
|
||||
return default_hp_space_wandb(trial)
|
||||
|
||||
|
||||
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
|
||||
HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
|
||||
}
|
||||
|
||||
|
||||
def default_hp_search_backend() -> str:
|
||||
available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
|
||||
if len(available_backends) > 0:
|
||||
name = available_backends[0].name
|
||||
if len(available_backends) > 1:
|
||||
logger.info(
|
||||
f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
|
||||
)
|
||||
return name
|
||||
raise RuntimeError(
|
||||
"No hyperparameter search backend available.\n"
|
||||
+ "\n".join(
|
||||
f" - To install {backend.name} run {backend.pip_install()}"
|
||||
for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
|
||||
)
|
||||
)
|
||||
@@ -177,15 +177,6 @@ def hp_params(trial):
|
||||
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
|
||||
|
||||
|
||||
def default_hp_search_backend():
|
||||
if is_optuna_available():
|
||||
return "optuna"
|
||||
elif is_ray_tune_available():
|
||||
return "ray"
|
||||
elif is_sigopt_available():
|
||||
return "sigopt"
|
||||
|
||||
|
||||
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
|
||||
import optuna
|
||||
|
||||
|
||||
@@ -36,18 +36,9 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
||||
# Integrations must be imported before ML frameworks:
|
||||
# isort: off
|
||||
from .integrations import (
|
||||
default_hp_search_backend,
|
||||
get_reporting_integration_callbacks,
|
||||
hp_params,
|
||||
is_fairscale_available,
|
||||
is_optuna_available,
|
||||
is_ray_tune_available,
|
||||
is_sigopt_available,
|
||||
is_wandb_available,
|
||||
run_hp_search_optuna,
|
||||
run_hp_search_ray,
|
||||
run_hp_search_sigopt,
|
||||
run_hp_search_wandb,
|
||||
)
|
||||
|
||||
# isort: on
|
||||
@@ -66,6 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
||||
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
||||
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
|
||||
from .dependency_versions_check import dep_version_check
|
||||
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
||||
from .modelcard import TrainingSummary
|
||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||
@@ -114,7 +106,6 @@ from .trainer_utils import (
|
||||
TrainerMemoryTracker,
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
denumpify_detensorize,
|
||||
enable_full_determinism,
|
||||
find_executable_batch_size,
|
||||
@@ -2517,41 +2508,20 @@ class Trainer:
|
||||
"""
|
||||
if backend is None:
|
||||
backend = default_hp_search_backend()
|
||||
if backend is None:
|
||||
raise RuntimeError(
|
||||
"At least one of optuna or ray should be installed. "
|
||||
"To install optuna run `pip install optuna`. "
|
||||
"To install ray run `pip install ray[tune]`. "
|
||||
"To install sigopt run `pip install sigopt`."
|
||||
)
|
||||
backend = HPSearchBackend(backend)
|
||||
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
|
||||
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
|
||||
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
|
||||
raise RuntimeError(
|
||||
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
||||
)
|
||||
if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
|
||||
raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
|
||||
if backend == HPSearchBackend.WANDB and not is_wandb_available():
|
||||
raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
|
||||
backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
|
||||
backend_obj.ensure_available()
|
||||
self.hp_search_backend = backend
|
||||
if self.model_init is None:
|
||||
raise RuntimeError(
|
||||
"To use hyperparameter search, you need to pass your model through a model_init function."
|
||||
)
|
||||
|
||||
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
||||
self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
|
||||
self.hp_name = hp_name
|
||||
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
||||
|
||||
backend_dict = {
|
||||
HPSearchBackend.OPTUNA: run_hp_search_optuna,
|
||||
HPSearchBackend.RAY: run_hp_search_ray,
|
||||
HPSearchBackend.SIGOPT: run_hp_search_sigopt,
|
||||
HPSearchBackend.WANDB: run_hp_search_wandb,
|
||||
}
|
||||
best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
|
||||
best_run = backend_obj.run(self, n_trials, direction, **kwargs)
|
||||
|
||||
self.hp_search_backend = None
|
||||
return best_run
|
||||
|
||||
@@ -301,14 +301,6 @@ class HPSearchBackend(ExplicitEnum):
|
||||
WANDB = "wandb"
|
||||
|
||||
|
||||
default_hp_space = {
|
||||
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
||||
HPSearchBackend.RAY: default_hp_space_ray,
|
||||
HPSearchBackend.SIGOPT: default_hp_space_sigopt,
|
||||
HPSearchBackend.WANDB: default_hp_space_wandb,
|
||||
}
|
||||
|
||||
|
||||
def is_main_process(local_rank):
|
||||
"""
|
||||
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
|
||||
|
||||
@@ -42,6 +42,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
TOKEN,
|
||||
@@ -72,7 +73,7 @@ from transformers.testing_utils import (
|
||||
require_wandb,
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -2803,3 +2804,11 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
||||
trainer.hyperparameter_search(
|
||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
|
||||
)
|
||||
|
||||
|
||||
class HyperParameterSearchBackendsTest(unittest.TestCase):
|
||||
def test_hyperparameter_search_backends(self):
|
||||
self.assertEqual(
|
||||
list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
|
||||
list(HPSearchBackend),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user