From b6295b26c500024ec733e18730463a5e94a7b716 Mon Sep 17 00:00:00 2001
From: Alex Hall
Date: Thu, 22 Jun 2023 20:28:25 +0200
Subject: [PATCH] Refactor hyperparameter search backends (#24384)

* Refactor hyperparameter search backends

* Simpler refactoring without abstract base class

* black

* review comments: specify name in class use methods instead of callable class attributes name constant better

* review comments: safer bool checking, log multiple available backends

* test ALL_HYPERPARAMETER_SEARCH_BACKENDS vs HPSearchBackend in unit test, not module. format with black.

* copyright
---
 src/transformers/__init__.py              |   1 +
 src/transformers/hyperparameter_search.py | 157 ++++++++++++++++++++++
 src/transformers/integrations.py          |   9 --
 src/transformers/trainer.py               |  40 +------
 src/transformers/trainer_utils.py         |   8 --
 tests/trainer/test_trainer.py             |  11 +-
 6 files changed, 173 insertions(+), 53 deletions(-)
 create mode 100644 src/transformers/hyperparameter_search.py

diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 548e9bf660..7bb7a4342e 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -98,6 +98,7 @@ _import_structure = {
     "file_utils": [],
     "generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
     "hf_argparser": ["HfArgumentParser"],
+    "hyperparameter_search": [],
     "image_transforms": [],
     "integrations": [
         "is_clearml_available",
diff --git a/src/transformers/hyperparameter_search.py b/src/transformers/hyperparameter_search.py
new file mode 100644
index 0000000000..f0f5f46a0d
--- /dev/null
+++ b/src/transformers/hyperparameter_search.py
@@ -0,0 +1,157 @@
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .integrations import (
+    is_optuna_available,
+    is_ray_tune_available,
+    is_sigopt_available,
+    is_wandb_available,
+    run_hp_search_optuna,
+    run_hp_search_ray,
+    run_hp_search_sigopt,
+    run_hp_search_wandb,
+)
+from .trainer_utils import (
+    HPSearchBackend,
+    default_hp_space_optuna,
+    default_hp_space_ray,
+    default_hp_space_sigopt,
+    default_hp_space_wandb,
+)
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HyperParamSearchBackendBase:
+    """Common interface of the hyperparameter search backends registered in this module."""
+
+    # Backend identifier; the registry below builds its keys with `HPSearchBackend(name)`.
+    name: str
+    # Package spec shown in the install hint; `pip_install` falls back to `name` when unset.
+    pip_package: str = None
+
+    def is_available(self):
+        """Return whether the backend can be used; subclasses must override."""
+        raise NotImplementedError
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        """Run the search for `trainer` and return its result; subclasses must override."""
+        raise NotImplementedError
+
+    def default_hp_space(self, trial):
+        """Return the backend's default hyperparameter space for `trial`; subclasses must override."""
+        raise NotImplementedError
+
+    def ensure_available(self):
+        """Raise a `RuntimeError` carrying the install hint if the backend is not available."""
+        if not self.is_available():
+            raise RuntimeError(
+                f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
+            )
+
+    @classmethod
+    def pip_install(cls):
+        """Return the back-quoted `pip install ...` command suggested for this backend."""
+        return f"`pip install {cls.pip_package or cls.name}`"
+
+
+class OptunaBackend(HyperParamSearchBackendBase):
+    name = "optuna"
+
+    def is_available(self):
+        return is_optuna_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_optuna(trial)
+
+
+class RayTuneBackend(HyperParamSearchBackendBase):
+    name = "ray"
+    pip_package = "'ray[tune]'"
+
+    def is_available(self):
+        # Check for Ray Tune itself rather than bare `ray`: the install hint asks for the
+        # `tune` extra, and the availability checks replaced by this module used the same test.
+        return is_ray_tune_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_ray(trial)
+
+
+class SigOptBackend(HyperParamSearchBackendBase):
+    name = "sigopt"
+
+    def is_available(self):
+        return is_sigopt_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_sigopt(trial)
+
+
+class WandbBackend(HyperParamSearchBackendBase):
+    name = "wandb"
+
+    def is_available(self):
+        return is_wandb_available()
+
+    def run(self, trainer, n_trials: int, direction: str, **kwargs):
+        return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
+
+    def default_hp_space(self, trial):
+        return default_hp_space_wandb(trial)
+
+
+# Registry of backend classes (not instances), keyed by the matching `HPSearchBackend` member.
+ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
+    HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
+}
+
+
+def default_hp_search_backend() -> str:
+    """
+    Return the name of the first available backend, in registry order.
+
+    Raises:
+        `RuntimeError`: if no backend is available; the message lists the install command of every backend.
+    """
+    # The registry holds classes while `is_available` is an instance method, so instantiate before asking.
+    available_backends = [
+        backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend().is_available()
+    ]
+    if len(available_backends) > 0:
+        name = available_backends[0].name
+        if len(available_backends) > 1:
+            logger.info(
+                f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
+            )
+        return name
+    raise RuntimeError(
+        "No hyperparameter search backend available.\n"
+        + "\n".join(
+            f" - To install {backend.name} run {backend.pip_install()}"
+            for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
+        )
+    )
diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index dc3d00d4d1..a4af915fdc 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -177,15 +177,6 @@ def hp_params(trial):
         raise RuntimeError(f"Unknown type for trial {trial.__class__}")
 
 
-def default_hp_search_backend():
-    if is_optuna_available():
-        return "optuna"
-    elif is_ray_tune_available():
-        return "ray"
-    elif is_sigopt_available():
-        return "sigopt"
-
-
 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import optuna
 
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 73ef46655b..d76f9b4971 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -36,18 +36,9 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
 # Integrations must be imported before ML frameworks:
 # isort: off
 from .integrations import (
-    default_hp_search_backend,
     get_reporting_integration_callbacks,
     hp_params,
     is_fairscale_available,
-    is_optuna_available,
-    is_ray_tune_available,
-    is_sigopt_available,
-    is_wandb_available,
-    run_hp_search_optuna,
-    run_hp_search_ray,
-    run_hp_search_sigopt,
-    run_hp_search_wandb,
 )
 
 # isort: on
@@ -66,6 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
+from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .modelcard import
TrainingSummary from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES @@ -114,7 +106,6 @@ from .trainer_utils import ( TrainerMemoryTracker, TrainOutput, default_compute_objective, - default_hp_space, denumpify_detensorize, enable_full_determinism, find_executable_batch_size, @@ -2517,41 +2508,20 @@ class Trainer: """ if backend is None: backend = default_hp_search_backend() - if backend is None: - raise RuntimeError( - "At least one of optuna or ray should be installed. " - "To install optuna run `pip install optuna`. " - "To install ray run `pip install ray[tune]`. " - "To install sigopt run `pip install sigopt`." - ) backend = HPSearchBackend(backend) - if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): - raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") - if backend == HPSearchBackend.RAY and not is_ray_tune_available(): - raise RuntimeError( - "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." - ) - if backend == HPSearchBackend.SIGOPT and not is_sigopt_available(): - raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.") - if backend == HPSearchBackend.WANDB and not is_wandb_available(): - raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.") + backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]() + backend_obj.ensure_available() self.hp_search_backend = backend if self.model_init is None: raise RuntimeError( "To use hyperparameter search, you need to pass your model through a model_init function." 
) - self.hp_space = default_hp_space[backend] if hp_space is None else hp_space + self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space self.hp_name = hp_name self.compute_objective = default_compute_objective if compute_objective is None else compute_objective - backend_dict = { - HPSearchBackend.OPTUNA: run_hp_search_optuna, - HPSearchBackend.RAY: run_hp_search_ray, - HPSearchBackend.SIGOPT: run_hp_search_sigopt, - HPSearchBackend.WANDB: run_hp_search_wandb, - } - best_run = backend_dict[backend](self, n_trials, direction, **kwargs) + best_run = backend_obj.run(self, n_trials, direction, **kwargs) self.hp_search_backend = None return best_run diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index d6008b53e7..13e72b3d44 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -301,14 +301,6 @@ class HPSearchBackend(ExplicitEnum): WANDB = "wandb" -default_hp_space = { - HPSearchBackend.OPTUNA: default_hp_space_optuna, - HPSearchBackend.RAY: default_hp_space_ray, - HPSearchBackend.SIGOPT: default_hp_space_sigopt, - HPSearchBackend.WANDB: default_hp_space_wandb, -} - - def is_main_process(local_rank): """ Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ea823d06a7..3442b52b01 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -42,6 +42,7 @@ from transformers import ( is_torch_available, logging, ) +from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS from transformers.testing_utils import ( ENDPOINT_STAGING, TOKEN, @@ -72,7 +73,7 @@ from transformers.testing_utils import ( require_wandb, slow, ) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend from transformers.training_args import 
OptimizerNames from transformers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -2803,3 +2804,11 @@ class TrainerHyperParameterWandbIntegrationTest(unittest.TestCase): trainer.hyperparameter_search( direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="wandb", n_trials=4, anonymous="must" ) + + +class HyperParameterSearchBackendsTest(unittest.TestCase): + def test_hyperparameter_search_backends(self): + self.assertEqual( + list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()), + list(HPSearchBackend), + )