From 23f9611c16f8bc528fab74690e32274809697f56 Mon Sep 17 00:00:00 2001 From: krfricke Date: Mon, 31 Aug 2020 19:38:46 +0100 Subject: [PATCH] Add checkpointing to Ray Tune HPO (#6747) * Introduce HPO checkpointing for PBT * Moved checkpoint saving * Fixed checkpoint subdir pass * Fixed style * Enable/disable checkpointing, check conditions for various tune schedulers incl. PBT * Adjust number of GPUs to number of jobs * Avoid mode pickling in ray * Move hp search to integrations --- src/transformers/integrations.py | 100 +++++++++++++++++++++++++++++- src/transformers/trainer.py | 50 +++++---------- src/transformers/trainer_utils.py | 8 ++- 3 files changed, 121 insertions(+), 37 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 3e597b029e..2382d50241 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -1,7 +1,13 @@ # Integrations with other Python libraries - import os +import numpy as np + +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, HPSearchBackend +from transformers.utils import logging + + +logger = logging.get_logger(__name__) try: import comet_ml # noqa: F401 @@ -75,3 +81,95 @@ def default_hp_search_backend(): return "optuna" elif is_ray_available(): return "ray" + + +def run_hp_search(trainer, n_trials, direction, kwargs): + def _objective(trial, checkpoint_dir=None): + model_path = None + if checkpoint_dir: + for subdir in os.listdir(checkpoint_dir): + if subdir.startswith(PREFIX_CHECKPOINT_DIR): + model_path = os.path.join(checkpoint_dir, subdir) + trainer.objective = None + trainer.train(model_path=model_path, trial=trial) + # If there hasn't been any evaluation during the training loop. + if getattr(trainer, "objective", None) is None: + metrics = trainer.evaluate() + trainer.objective = trainer.compute_objective(metrics) + if trainer.hp_search_backend == HPSearchBackend.RAY: + trainer._tune_save_checkpoint() + ray.tune.report(objective=trainer.objective) + return trainer.objective + + if trainer.hp_search_backend == HPSearchBackend.OPTUNA: + timeout = kwargs.pop("timeout", None) + n_jobs = kwargs.pop("n_jobs", 1) + study = optuna.create_study(direction=direction, **kwargs) + study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs) + best_trial = study.best_trial + best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params) + elif trainer.hp_search_backend == HPSearchBackend.RAY: + # The model and TensorBoard writer do not pickle so we have to remove them (if they exists) + # while doing the ray hp search. + _tb_writer = trainer.tb_writer + trainer.tb_writer = None + trainer.model = None + # Setup default `resources_per_trial` and `reporter`. + if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0: + # `args.n_gpu` is considered the total number of GPUs that will be split + # among the `n_jobs` + n_jobs = int(kwargs.pop("n_jobs", 1)) + num_gpus_per_trial = trainer.args.n_gpu + if num_gpus_per_trial / n_jobs >= 1: + num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs)) + kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial} + + if "reporter" not in kwargs: + from ray.tune import CLIReporter + + kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"]) + if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0: + # `keep_checkpoints_num=0` would disabled checkpointing + trainer.use_tune_checkpoints = True + if kwargs["keep_checkpoints_num"] > 1: + logger.warning( + "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, " + "consider setting `keep_checkpoints_num=1`." + ) + if "scheduler" in kwargs: + from ray.tune.schedulers import ( + ASHAScheduler, + HyperBandForBOHB, + MedianStoppingRule, + PopulationBasedTraining, + ) + + # Check if checkpointing is enabled for PopulationBasedTraining + if isinstance(kwargs["scheduler"], PopulationBasedTraining): + if not trainer.use_tune_checkpoints: + logger.warning( + "You are using PopulationBasedTraining but you haven't enabled checkpointing. " + "This means your trials will train from scratch everytime they are exploiting " + "new configurations. Consider enabling checkpointing by passing " + "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`." + ) + + # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting. + if isinstance( + kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining) + ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training): + raise RuntimeError( + "You are using {cls} as a scheduler but you haven't enabled evaluation during training. " + "This means your trials will not report intermediate results to Ray Tune, and " + "can thus not be stopped early or used to exploit other trials parameters. " + "If this is what you want, do not use {cls}. If you would like to use {cls}, " + "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the " + "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__) + ) + + analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs) + best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3]) + best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config) + trainer.tb_writer = _tb_writer + + return best_run diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 74b00e7d1d..428259bb48 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -27,6 +27,7 @@ from .integrations import ( is_ray_available, is_tensorboard_available, is_wandb_available, + run_hp_search, ) from .modeling_utils import PreTrainedModel from .optimization import AdamW, get_linear_schedule_with_warmup @@ -295,6 +296,7 @@ class Trainer: if self.args.fp16 and _use_native_amp: self.scaler = torch.cuda.amp.GradScaler() self.hp_search_backend = None + self.use_tune_checkpoints = False def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None): if not self.args.remove_unused_columns: @@ -544,8 +546,21 @@ class Trainer: if trial.should_prune(): raise optuna.TrialPruned() elif self.hp_search_backend == HPSearchBackend.RAY: + if self.global_step % self.args.save_steps == 0: + self._tune_save_checkpoint() tune.report(objective=self.objective, **metrics) + def _tune_save_checkpoint(self): + if not self.use_tune_checkpoints: + return + with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir: + self.args.output_dir = checkpoint_dir + output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") + self.save_model(output_dir) + if self.is_world_master(): + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None): """ Main training entry point. @@ -869,40 +884,7 @@ class Trainer: self.hp_space = default_hp_space[backend] if hp_space is None else hp_space self.compute_objective = default_compute_objective if compute_objective is None else compute_objective - def _objective(trial): - self.objective = None - self.train(trial=trial) - # If there hasn't been any evaluation during the training loop. - if getattr(self, "objective", None) is None: - metrics = self.evaluate() - self.objective = self.compute_objective(metrics) - if self.hp_search_backend == HPSearchBackend.RAY: - tune.report(objective=self.objective) - return self.objective - - if self.hp_search_backend == HPSearchBackend.OPTUNA: - timeout = kwargs.pop("timeout", None) - n_jobs = kwargs.pop("n_jobs", 1) - study = optuna.create_study(direction=direction, **kwargs) - study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs) - best_trial = study.best_trial - best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params) - elif self.hp_search_backend == HPSearchBackend.RAY: - # The TensorBoard writer does not pickle so we have to remove it (if it exists) while doing the ray hp - # search. - _tb_writer = self.tb_writer - self.tb_writer = None - # Setup default `resources_per_trial` and `reporter`. - if "resources_per_trial" not in kwargs and self.args.n_gpu > 0: - kwargs["resources_per_trial"] = {"gpu": self.args.n_gpu} - if "reporter" not in kwargs: - from ray.tune import CLIReporter - - kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"]) - analysis = tune.run(_objective, config=self.hp_space(None), num_samples=n_trials, **kwargs) - best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3]) - best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config) - self.tb_writer = _tb_writer + best_run = run_hp_search(self, n_trials, direction, kwargs) self.hp_search_backend = None return best_run diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index d5556f16c3..f9b560b402 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -4,7 +4,6 @@ from typing import Any, Dict, NamedTuple, Optional import numpy as np from .file_utils import is_tf_available, is_torch_available -from .integrations import is_ray_available from .tokenization_utils_base import ExplicitEnum @@ -93,6 +92,9 @@ def default_compute_objective(metrics: Dict[str, float]) -> float: def default_hp_space_optuna(trial) -> Dict[str, float]: + from .integrations import is_optuna_available + + assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`" return { "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True), "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5), @@ -102,12 +104,14 @@ def default_hp_space_optuna(trial) -> Dict[str, float]: def default_hp_space_ray(trial) -> Dict[str, float]: + from .integrations import is_ray_available + assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`" from ray import tune return { "learning_rate": tune.loguniform(1e-6, 1e-4), - "num_train_epochs": tune.choice(range(1, 6)), + "num_train_epochs": tune.choice(list(range(1, 6))), "seed": tune.uniform(1, 40), "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]), }