Ray Tune Integration Bug Fixes (#10406)

* fixes

* update resources

* formatting

* remove import

* add log statement

* use fstring

* add period

* Update src/transformers/integrations.py
This commit is contained in:
Amog Kamsetty
2021-02-26 16:06:08 -08:00
committed by GitHub
parent 98569d4ba2
commit a85eb616f7
2 changed files with 30 additions and 22 deletions

View File

@@ -17,7 +17,6 @@ Integrations with other Python libraries.
import importlib.util import importlib.util
import io import io
import json import json
import math
import numbers import numbers
import os import os
import re import re
@@ -174,16 +173,23 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
_tb_writer = trainer.pop_callback(TensorBoardCallback) _tb_writer = trainer.pop_callback(TensorBoardCallback)
trainer.model = None trainer.model = None
# Setup default `resources_per_trial` and `reporter`. # Setup default `resources_per_trial`.
if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0: if "resources_per_trial" not in kwargs:
# `args.n_gpu` is considered the total number of GPUs that will be split # Default to 1 CPU and 1 GPU (if applicable) per trial.
# among the `n_jobs` kwargs["resources_per_trial"] = {"cpu": 1}
n_jobs = int(kwargs.pop("n_jobs", 1)) if trainer.args.n_gpu > 0:
num_gpus_per_trial = trainer.args.n_gpu kwargs["resources_per_trial"]["gpu"] = 1
if num_gpus_per_trial / n_jobs >= 1: resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
num_gpus_per_trial = int(math.ceil(num_gpus_per_trial / n_jobs)) logger.info(
kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial} "No `resources_per_trial` arg was passed into "
"`hyperparameter_search`. Setting it to a default value "
f"of {resource_msg} for each trial."
)
# Make sure each trainer only uses GPUs that were allocated per trial.
gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
trainer.args._n_gpu = gpus_per_trial
# Setup default `progress_reporter`.
if "progress_reporter" not in kwargs: if "progress_reporter" not in kwargs:
from ray.tune import CLIReporter from ray.tune import CLIReporter
@@ -193,7 +199,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
trainer.use_tune_checkpoints = True trainer.use_tune_checkpoints = True
if kwargs["keep_checkpoints_num"] > 1: if kwargs["keep_checkpoints_num"] > 1:
logger.warning( logger.warning(
"Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, " f"Currently keeping {kwargs['keep_checkpoint_num']} checkpoints for each trial. "
"Checkpoints are usually huge, "
"consider setting `keep_checkpoints_num=1`." "consider setting `keep_checkpoints_num=1`."
) )
if "scheduler" in kwargs: if "scheduler" in kwargs:

View File

@@ -707,7 +707,7 @@ class Trainer:
elif self.hp_search_backend == HPSearchBackend.RAY: elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune from ray import tune
if self.state.global_step % self.args.save_steps == 0: if self.control.should_save:
self._tune_save_checkpoint() self._tune_save_checkpoint()
tune.report(objective=self.objective, **metrics) tune.report(objective=self.objective, **metrics)
@@ -717,8 +717,7 @@ class Trainer:
if not self.use_tune_checkpoints: if not self.use_tune_checkpoints:
return return
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
self.args.output_dir = checkpoint_dir output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir) self.save_model(output_dir)
if self.is_world_process_zero(): if self.is_world_process_zero():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
@@ -1201,12 +1200,12 @@ class Trainer:
run_id = tune.get_trial_id() run_id = tune.get_trial_id()
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder) run_dir = os.path.join(self.args.output_dir, run_name)
else: else:
output_dir = os.path.join(self.args.output_dir, checkpoint_folder) run_dir = self.args.output_dir
self.store_flos() self.store_flos()
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir) self.save_model(output_dir)
if self.deepspeed: if self.deepspeed:
self.deepspeed.save_checkpoint(output_dir) self.deepspeed.save_checkpoint(output_dir)
@@ -1250,7 +1249,7 @@ class Trainer:
# Maybe delete some older checkpoints. # Maybe delete some older checkpoints.
if self.is_world_process_zero(): if self.is_world_process_zero():
self._rotate_checkpoints(use_mtime=True) self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def _load_optimizer_and_scheduler(self, checkpoint): def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them.""" """If optimizer and scheduler states exist, load them."""
@@ -1559,10 +1558,12 @@ class Trainer:
else: else:
self.state.total_flos = self._total_flos self.state.total_flos = self._total_flos
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: def _sorted_checkpoints(
self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
) -> List[str]:
ordering_and_checkpoint_path = [] ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")] glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*")]
for path in glob_checkpoints: for path in glob_checkpoints:
if use_mtime: if use_mtime:
@@ -1583,12 +1584,12 @@ class Trainer:
) )
return checkpoints_sorted return checkpoints_sorted
def _rotate_checkpoints(self, use_mtime=False) -> None: def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
if self.args.save_total_limit is None or self.args.save_total_limit <= 0: if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
return return
# Check if we should delete older checkpoint(s) # Check if we should delete older checkpoint(s)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime) checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
if len(checkpoints_sorted) <= self.args.save_total_limit: if len(checkpoints_sorted) <= self.args.save_total_limit:
return return