Refactor hyperparameter search backends (#24384)
* Refactor hyperparameter search backends * Simpler refactoring without abstract base class * black * review comments: specify name in class use methods instead of callable class attributes name constant better * review comments: safer bool checking, log multiple available backends * test ALL_HYPERPARAMETER_SEARCH_BACKENDS vs HPSearchBackend in unit test, not module. format with black. * copyright
This commit is contained in:
@@ -98,6 +98,7 @@ _import_structure = {
|
|||||||
"file_utils": [],
|
"file_utils": [],
|
||||||
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
|
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
|
||||||
"hf_argparser": ["HfArgumentParser"],
|
"hf_argparser": ["HfArgumentParser"],
|
||||||
|
"hyperparameter_search": [],
|
||||||
"image_transforms": [],
|
"image_transforms": [],
|
||||||
"integrations": [
|
"integrations": [
|
||||||
"is_clearml_available",
|
"is_clearml_available",
|
||||||
|
|||||||
136
src/transformers/hyperparameter_search.py
Normal file
136
src/transformers/hyperparameter_search.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023-present the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .integrations import (
|
||||||
|
is_optuna_available,
|
||||||
|
is_ray_available,
|
||||||
|
is_sigopt_available,
|
||||||
|
is_wandb_available,
|
||||||
|
run_hp_search_optuna,
|
||||||
|
run_hp_search_ray,
|
||||||
|
run_hp_search_sigopt,
|
||||||
|
run_hp_search_wandb,
|
||||||
|
)
|
||||||
|
from .trainer_utils import (
|
||||||
|
HPSearchBackend,
|
||||||
|
default_hp_space_optuna,
|
||||||
|
default_hp_space_ray,
|
||||||
|
default_hp_space_sigopt,
|
||||||
|
default_hp_space_wandb,
|
||||||
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperParamSearchBackendBase:
|
||||||
|
name: str
|
||||||
|
pip_package: str = None
|
||||||
|
|
||||||
|
def is_available(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def default_hp_space(self, trial):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def ensure_available(self):
|
||||||
|
if not self.is_available():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pip_install(cls):
|
||||||
|
return f"`pip install {cls.pip_package or cls.name}`"
|
||||||
|
|
||||||
|
|
||||||
|
class OptunaBackend(HyperParamSearchBackendBase):
|
||||||
|
name = "optuna"
|
||||||
|
|
||||||
|
def is_available(self):
|
||||||
|
return is_optuna_available()
|
||||||
|
|
||||||
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||||
|
return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
|
||||||
|
|
||||||
|
def default_hp_space(self, trial):
|
||||||
|
return default_hp_space_optuna(trial)
|
||||||
|
|
||||||
|
|
||||||
|
class RayTuneBackend(HyperParamSearchBackendBase):
|
||||||
|
name = "ray"
|
||||||
|
pip_package = "'ray[tune]'"
|
||||||
|
|
||||||
|
def is_available(self):
|
||||||
|
return is_ray_available()
|
||||||
|
|
||||||
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||||
|
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
|
||||||
|
|
||||||
|
def default_hp_space(self, trial):
|
||||||
|
return default_hp_space_ray(trial)
|
||||||
|
|
||||||
|
|
||||||
|
class SigOptBackend(HyperParamSearchBackendBase):
|
||||||
|
name = "sigopt"
|
||||||
|
|
||||||
|
def is_available(self):
|
||||||
|
return is_sigopt_available()
|
||||||
|
|
||||||
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||||
|
return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
|
||||||
|
|
||||||
|
def default_hp_space(self, trial):
|
||||||
|
return default_hp_space_sigopt(trial)
|
||||||
|
|
||||||
|
|
||||||
|
class WandbBackend(HyperParamSearchBackendBase):
|
||||||
|
name = "wandb"
|
||||||
|
|
||||||
|
def is_available(self):
|
||||||
|
return is_wandb_available()
|
||||||
|
|
||||||
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||||
|
return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
|
||||||
|
|
||||||
|
def default_hp_space(self, trial):
|
||||||
|
return default_hp_space_wandb(trial)
|
||||||
|
|
||||||
|
|
||||||
|
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
|
||||||
|
HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def default_hp_search_backend() -> str:
|
||||||
|
available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
|
||||||
|
if len(available_backends) > 0:
|
||||||
|
name = available_backends[0].name
|
||||||
|
if len(available_backends) > 1:
|
||||||
|
logger.info(
|
||||||
|
f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
|
||||||
|
)
|
||||||
|
return name
|
||||||
|
raise RuntimeError(
|
||||||
|
"No hyperparameter search backend available.\n"
|
||||||
|
+ "\n".join(
|
||||||
|
f" - To install {backend.name} run {backend.pip_install()}"
|
||||||
|
for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
|
||||||
|
)
|
||||||
|
)
|
||||||
@@ -177,15 +177,6 @@ def hp_params(trial):
|
|||||||
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
|
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
|
||||||
|
|
||||||
|
|
||||||
def default_hp_search_backend():
|
|
||||||
if is_optuna_available():
|
|
||||||
return "optuna"
|
|
||||||
elif is_ray_tune_available():
|
|
||||||
return "ray"
|
|
||||||
elif is_sigopt_available():
|
|
||||||
return "sigopt"
|
|
||||||
|
|
||||||
|
|
||||||
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
|
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
|
||||||
import optuna
|
import optuna
|
||||||
|
|
||||||
|
|||||||
@@ -36,18 +36,9 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
|||||||
# Integrations must be imported before ML frameworks:
|
# Integrations must be imported before ML frameworks:
|
||||||
# isort: off
|
# isort: off
|
||||||
from .integrations import (
|
from .integrations import (
|
||||||
default_hp_search_backend,
|
|
||||||
get_reporting_integration_callbacks,
|
get_reporting_integration_callbacks,
|
||||||
hp_params,
|
hp_params,
|
||||||
is_fairscale_available,
|
is_fairscale_available,
|
||||||
is_optuna_available,
|
|
||||||
is_ray_tune_available,
|
|
||||||
is_sigopt_available,
|
|
||||||
is_wandb_available,
|
|
||||||
run_hp_search_optuna,
|
|
||||||
run_hp_search_ray,
|
|
||||||
run_hp_search_sigopt,
|
|
||||||
run_hp_search_wandb,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# isort: on
|
# isort: on
|
||||||
@@ -66,6 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
|||||||
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
||||||
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
|
from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
|
||||||
from .dependency_versions_check import dep_version_check
|
from .dependency_versions_check import dep_version_check
|
||||||
|
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
||||||
from .modelcard import TrainingSummary
|
from .modelcard import TrainingSummary
|
||||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||||
@@ -114,7 +106,6 @@ from .trainer_utils import (
|
|||||||
TrainerMemoryTracker,
|
TrainerMemoryTracker,
|
||||||
TrainOutput,
|
TrainOutput,
|
||||||
default_compute_objective,
|
default_compute_objective,
|
||||||
default_hp_space,
|
|
||||||
denumpify_detensorize,
|
denumpify_detensorize,
|
||||||
enable_full_determinism,
|
enable_full_determinism,
|
||||||
find_executable_batch_size,
|
find_executable_batch_size,
|
||||||
@@ -2517,41 +2508,20 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
if backend is None:
|
if backend is None:
|
||||||
backend = default_hp_search_backend()
|
backend = default_hp_search_backend()
|
||||||
if backend is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"At least one of optuna or ray should be installed. "
|
|
||||||
"To install optuna run `pip install optuna`. "
|
|
||||||
"To install ray run `pip install ray[tune]`. "
|
|
||||||
"To install sigopt run `pip install sigopt`."
|
|
||||||
)
|
|
||||||
backend = HPSearchBackend(backend)
|
backend = HPSearchBackend(backend)
|
||||||
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
|
backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
|
||||||
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
|
backend_obj.ensure_available()
|
||||||
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
|
|
||||||
raise RuntimeError(
|
|
||||||
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
|
||||||
)
|
|
||||||
if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
|
|
||||||
raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
|
|
||||||
if backend == HPSearchBackend.WANDB and not is_wandb_available():
|
|
||||||
raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
|
|
||||||
self.hp_search_backend = backend
|
self.hp_search_backend = backend
|
||||||
if self.model_init is None:
|
if self.model_init is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"To use hyperparameter search, you need to pass your model through a model_init function."
|
"To use hyperparameter search, you need to pass your model through a model_init function."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
|
||||||
self.hp_name = hp_name
|
self.hp_name = hp_name
|
||||||
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
||||||
|
|
||||||
backend_dict = {
|
best_run = backend_obj.run(self, n_trials, direction, **kwargs)
|
||||||
HPSearchBackend.OPTUNA: run_hp_search_optuna,
|
|
||||||
HPSearchBackend.RAY: run_hp_search_ray,
|
|
||||||
HPSearchBackend.SIGOPT: run_hp_search_sigopt,
|
|
||||||
HPSearchBackend.WANDB: run_hp_search_wandb,
|
|
||||||
}
|
|
||||||
best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
|
|
||||||
|
|
||||||
self.hp_search_backend = None
|
self.hp_search_backend = None
|
||||||
return best_run
|
return best_run
|
||||||
|
|||||||
@@ -301,14 +301,6 @@ class HPSearchBackend(ExplicitEnum):
|
|||||||
WANDB = "wandb"
|
WANDB = "wandb"
|
||||||
|
|
||||||
|
|
||||||
default_hp_space = {
|
|
||||||
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
|
||||||
HPSearchBackend.RAY: default_hp_space_ray,
|
|
||||||
HPSearchBackend.SIGOPT: default_hp_space_sigopt,
|
|
||||||
HPSearchBackend.WANDB: default_hp_space_wandb,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def is_main_process(local_rank):
|
def is_main_process(local_rank):
|
||||||
"""
|
"""
|
||||||
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
|
Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
ENDPOINT_STAGING,
|
ENDPOINT_STAGING,
|
||||||
TOKEN,
|
TOKEN,
|
||||||
@@ -72,7 +73,7 @@ from transformers.testing_utils import (
|
|||||||
require_wandb,
|
require_wandb,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
@@ -2803,3 +2804,11 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase):
|
|||||||
trainer.hyperparameter_search(
|
trainer.hyperparameter_search(
|
||||||
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
|
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperParameterSearchBackendsTest(unittest.TestCase):
|
||||||
|
def test_hyperparameter_search_backends(self):
|
||||||
|
self.assertEqual(
|
||||||
|
list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
|
||||||
|
list(HPSearchBackend),
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user