TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)

Improved TensorBoard and Wandb integration, as well as optuna and ray/tune support, with minor modifications to trainer core code.
This commit is contained in:
François Lagunas
2020-10-21 17:18:52 +02:00
committed by GitHub
parent bf162ce8ca
commit e174bfeb34
7 changed files with 344 additions and 23 deletions

View File

@@ -39,6 +39,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
hp_params,
is_comet_available,
is_optuna_available,
is_ray_available,
@@ -224,6 +225,7 @@ class Trainer:
model is not None or model_init is not None
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
self.model_init = model_init
self.hp_name = None
if model is None and model_init is not None:
model = self.call_model_init()
self.model = model.to(args.device) if model is not None else None
@@ -508,8 +510,11 @@ class Trainer:
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """
self._trial = trial
if self.hp_search_backend is None or trial is None:
return
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
for key, value in params.items():
if not hasattr(self.args, key):
@@ -558,7 +563,10 @@ class Trainer:
elif model_init_argcount == 1:
model = self.model_init(trial)
else:
raise Exception("model_init should have 0 or 1 argument.")
raise RuntimeError("model_init should have 0 or 1 argument.")
if model is None:
raise RuntimeError("model_init should not return None.")
return model
@@ -617,6 +625,7 @@ class Trainer:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
# Check if saved optimizer or scheduler states exist
if (
@@ -702,6 +711,8 @@ class Trainer:
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
self.state.trial_params = hp_params(trial) if trial is not None else None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
@@ -783,13 +794,13 @@ class Trainer:
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -823,7 +834,7 @@ class Trainer:
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
@@ -842,6 +853,7 @@ class Trainer:
if self.control.should_evaluate:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
if self.control.should_save:
@@ -857,12 +869,15 @@ class Trainer:
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
else:
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.store_flos()
self.save_model(output_dir)
# Save optimizer and scheduler
@@ -909,6 +924,7 @@ class Trainer:
n_trials: int = 20,
direction: str = "minimize",
backend: Optional[Union["str", HPSearchBackend]] = None,
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
**kwargs
) -> BestRun:
"""
@@ -966,13 +982,13 @@ class Trainer:
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
self.hp_search_backend = backend
if self.model_init is None:
raise RuntimeError(
"To use hyperparameter search, you need to pass your model through a model_init function."
)
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
self.hp_name = hp_name
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
@@ -997,12 +1013,12 @@ class Trainer:
FutureWarning,
)
return self._log(logs)
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self._total_flos is not None:
self.store_flos()
logs["total_flos"] = self.state.total_flos
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)