[integration] Update Ray Tune integration for Ray 2.7 (#26499)
* fix tune integration for ray 2.7+ Signed-off-by: Justin Yu <justinvyu@anyscale.com> * add version check for ray tune backend availability Signed-off-by: Justin Yu <justinvyu@anyscale.com> * missing import Signed-off-by: Justin Yu <justinvyu@anyscale.com> * pin min version instead Signed-off-by: Justin Yu <justinvyu@anyscale.com> * address comments Signed-off-by: Justin Yu <justinvyu@anyscale.com> * some fixes Signed-off-by: Justin Yu <justinvyu@anyscale.com> * fix unnecessary final checkpoint Signed-off-by: Justin Yu <justinvyu@anyscale.com> * fix lint Signed-off-by: Justin Yu <justinvyu@anyscale.com> * dep table fix Signed-off-by: Justin Yu <justinvyu@anyscale.com> * fix lint Signed-off-by: Justin Yu <justinvyu@anyscale.com> --------- Signed-off-by: Justin Yu <justinvyu@anyscale.com>
This commit is contained in:
2
setup.py
2
setup.py
@@ -149,7 +149,7 @@ _deps = [
|
|||||||
"pytest-timeout",
|
"pytest-timeout",
|
||||||
"pytest-xdist",
|
"pytest-xdist",
|
||||||
"python>=3.8.0",
|
"python>=3.8.0",
|
||||||
"ray[tune]",
|
"ray[tune]>=2.7.0",
|
||||||
"regex!=2019.12.17",
|
"regex!=2019.12.17",
|
||||||
"requests",
|
"requests",
|
||||||
"rhoknp>=1.1.0,<1.3.1",
|
"rhoknp>=1.1.0,<1.3.1",
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ deps = {
|
|||||||
"pytest-timeout": "pytest-timeout",
|
"pytest-timeout": "pytest-timeout",
|
||||||
"pytest-xdist": "pytest-xdist",
|
"pytest-xdist": "pytest-xdist",
|
||||||
"python": "python>=3.8.0",
|
"python": "python>=3.8.0",
|
||||||
"ray[tune]": "ray[tune]",
|
"ray[tune]": "ray[tune]>=2.7.0",
|
||||||
"regex": "regex!=2019.12.17",
|
"regex": "regex!=2019.12.17",
|
||||||
"requests": "requests",
|
"requests": "requests",
|
||||||
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
|
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
from .integrations import (
|
from .integrations import (
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_tune_available,
|
||||||
is_sigopt_available,
|
is_sigopt_available,
|
||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
run_hp_search_optuna,
|
run_hp_search_optuna,
|
||||||
@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_available():
|
def is_available():
|
||||||
return is_ray_available()
|
return is_ray_tune_available()
|
||||||
|
|
||||||
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
def run(self, trainer, n_trials: int, direction: str, **kwargs):
|
||||||
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
|
return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
|
||||||
|
|||||||
@@ -236,8 +236,9 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
|
|||||||
|
|
||||||
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
|
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
|
||||||
import ray
|
import ray
|
||||||
|
import ray.train
|
||||||
|
|
||||||
def _objective(trial, local_trainer, checkpoint_dir=None):
|
def _objective(trial: dict, local_trainer):
|
||||||
try:
|
try:
|
||||||
from transformers.utils.notebook import NotebookProgressCallback
|
from transformers.utils.notebook import NotebookProgressCallback
|
||||||
|
|
||||||
@@ -246,19 +247,34 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
checkpoint = None
|
|
||||||
if checkpoint_dir:
|
|
||||||
for subdir in os.listdir(checkpoint_dir):
|
|
||||||
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
|
|
||||||
checkpoint = os.path.join(checkpoint_dir, subdir)
|
|
||||||
local_trainer.objective = None
|
local_trainer.objective = None
|
||||||
local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
|
|
||||||
|
checkpoint = ray.train.get_checkpoint()
|
||||||
|
if checkpoint:
|
||||||
|
# Upon trial resume, the local_trainer's objective gets reset to None.
|
||||||
|
# If `local_trainer.train` is a noop (training has already reached
|
||||||
|
# the target number of epochs/steps), then this would
|
||||||
|
# trigger an unnecessary extra checkpoint at the end of training.
|
||||||
|
# -> Set the objective to a dummy value upon resume as a workaround.
|
||||||
|
local_trainer.objective = "objective"
|
||||||
|
|
||||||
|
with checkpoint.as_directory() as checkpoint_dir:
|
||||||
|
checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
|
||||||
|
local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
|
||||||
|
else:
|
||||||
|
local_trainer.train(trial=trial)
|
||||||
|
|
||||||
# If there hasn't been any evaluation during the training loop.
|
# If there hasn't been any evaluation during the training loop.
|
||||||
if getattr(local_trainer, "objective", None) is None:
|
if getattr(local_trainer, "objective", None) is None:
|
||||||
metrics = local_trainer.evaluate()
|
metrics = local_trainer.evaluate()
|
||||||
local_trainer.objective = local_trainer.compute_objective(metrics)
|
local_trainer.objective = local_trainer.compute_objective(metrics)
|
||||||
local_trainer._tune_save_checkpoint()
|
|
||||||
ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
|
metrics.update({"objective": local_trainer.objective, "done": True})
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
|
||||||
|
local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
|
||||||
|
checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
|
||||||
|
ray.train.report(metrics, checkpoint=checkpoint)
|
||||||
|
|
||||||
if not trainer._memory_tracker.skip_memory_metrics:
|
if not trainer._memory_tracker.skip_memory_metrics:
|
||||||
from ..trainer_utils import TrainerMemoryTracker
|
from ..trainer_utils import TrainerMemoryTracker
|
||||||
@@ -296,28 +312,10 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||||||
from ray.tune import CLIReporter
|
from ray.tune import CLIReporter
|
||||||
|
|
||||||
kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
|
kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
|
||||||
if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
|
|
||||||
# `keep_checkpoints_num=0` would disabled checkpointing
|
|
||||||
trainer.use_tune_checkpoints = True
|
|
||||||
if kwargs["keep_checkpoints_num"] > 1:
|
|
||||||
logger.warning(
|
|
||||||
f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
|
|
||||||
"Checkpoints are usually huge, "
|
|
||||||
"consider setting `keep_checkpoints_num=1`."
|
|
||||||
)
|
|
||||||
if "scheduler" in kwargs:
|
if "scheduler" in kwargs:
|
||||||
from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
|
from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
|
||||||
|
|
||||||
# Check if checkpointing is enabled for PopulationBasedTraining
|
|
||||||
if isinstance(kwargs["scheduler"], PopulationBasedTraining):
|
|
||||||
if not trainer.use_tune_checkpoints:
|
|
||||||
logger.warning(
|
|
||||||
"You are using PopulationBasedTraining but you haven't enabled checkpointing. "
|
|
||||||
"This means your trials will train from scratch everytime they are exploiting "
|
|
||||||
"new configurations. Consider enabling checkpointing by passing "
|
|
||||||
"`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
|
# Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
|
||||||
if isinstance(
|
if isinstance(
|
||||||
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
|
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@@ -595,7 +596,6 @@ class Trainer:
|
|||||||
# returned to 0 every time flos need to be logged
|
# returned to 0 every time flos need to be logged
|
||||||
self.current_flos = 0
|
self.current_flos = 0
|
||||||
self.hp_search_backend = None
|
self.hp_search_backend = None
|
||||||
self.use_tune_checkpoints = False
|
|
||||||
default_label_names = find_labels(self.model.__class__)
|
default_label_names = find_labels(self.model.__class__)
|
||||||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||||
self.can_return_loss = can_return_loss(self.model.__class__)
|
self.can_return_loss = can_return_loss(self.model.__class__)
|
||||||
@@ -1201,7 +1201,8 @@ class Trainer:
|
|||||||
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
|
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
|
||||||
if self.hp_search_backend is None or trial is None:
|
if self.hp_search_backend is None or trial is None:
|
||||||
return
|
return
|
||||||
self.objective = self.compute_objective(metrics.copy())
|
metrics = metrics.copy()
|
||||||
|
self.objective = self.compute_objective(metrics)
|
||||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||||
import optuna
|
import optuna
|
||||||
|
|
||||||
@@ -1211,18 +1212,17 @@ class Trainer:
|
|||||||
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||||
raise optuna.TrialPruned()
|
raise optuna.TrialPruned()
|
||||||
elif self.hp_search_backend == HPSearchBackend.RAY:
|
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||||
from ray import tune
|
import ray.train
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
|
||||||
|
checkpoint = None
|
||||||
if self.control.should_save:
|
if self.control.should_save:
|
||||||
self._tune_save_checkpoint()
|
self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
|
||||||
tune.report(objective=self.objective, **metrics)
|
checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
|
||||||
|
metrics["objective"] = self.objective
|
||||||
|
ray.train.report(metrics, checkpoint=checkpoint)
|
||||||
|
|
||||||
def _tune_save_checkpoint(self):
|
def _tune_save_checkpoint(self, checkpoint_dir: str):
|
||||||
from ray import tune
|
|
||||||
|
|
||||||
if not self.use_tune_checkpoints:
|
|
||||||
return
|
|
||||||
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
|
|
||||||
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||||
self.save_model(output_dir, _internal_call=True)
|
self.save_model(output_dir, _internal_call=True)
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
@@ -2004,9 +2004,9 @@ class Trainer:
|
|||||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||||
run_id = trial.number
|
run_id = trial.number
|
||||||
elif self.hp_search_backend == HPSearchBackend.RAY:
|
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||||
from ray import tune
|
import ray.train
|
||||||
|
|
||||||
run_id = tune.get_trial_id()
|
run_id = ray.train.get_context().get_trial_id()
|
||||||
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
|
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
|
||||||
run_id = trial.id
|
run_id = trial.id
|
||||||
elif self.hp_search_backend == HPSearchBackend.WANDB:
|
elif self.hp_search_backend == HPSearchBackend.WANDB:
|
||||||
|
|||||||
Reference in New Issue
Block a user