TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)
Improved TensorBoard and Wandb integration, as well as optuna and ray/tune support, with minor modifications to trainer core code.
This commit is contained in:
@@ -39,6 +39,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||
from .integrations import (
|
||||
default_hp_search_backend,
|
||||
hp_params,
|
||||
is_comet_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
@@ -224,6 +225,7 @@ class Trainer:
|
||||
model is not None or model_init is not None
|
||||
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
||||
self.model_init = model_init
|
||||
self.hp_name = None
|
||||
if model is None and model_init is not None:
|
||||
model = self.call_model_init()
|
||||
self.model = model.to(args.device) if model is not None else None
|
||||
@@ -508,8 +510,11 @@ class Trainer:
|
||||
|
||||
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||
""" HP search setup code """
|
||||
self._trial = trial
|
||||
|
||||
if self.hp_search_backend is None or trial is None:
|
||||
return
|
||||
|
||||
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
||||
for key, value in params.items():
|
||||
if not hasattr(self.args, key):
|
||||
@@ -558,7 +563,10 @@ class Trainer:
|
||||
elif model_init_argcount == 1:
|
||||
model = self.model_init(trial)
|
||||
else:
|
||||
raise Exception("model_init should have 0 or 1 argument.")
|
||||
raise RuntimeError("model_init should have 0 or 1 argument.")
|
||||
|
||||
if model is None:
|
||||
raise RuntimeError("model_init should not return None.")
|
||||
|
||||
return model
|
||||
|
||||
@@ -617,6 +625,7 @@ class Trainer:
|
||||
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
self.state = TrainerState()
|
||||
self.state.is_hyper_param_search = trial is not None
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if (
|
||||
@@ -702,6 +711,8 @@ class Trainer:
|
||||
self.callback_handler.optimizer = self.optimizer
|
||||
self.callback_handler.lr_scheduler = self.lr_scheduler
|
||||
self.callback_handler.train_dataloader = train_dataloader
|
||||
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
|
||||
self.state.trial_params = hp_params(trial) if trial is not None else None
|
||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||
# to set this after the load.
|
||||
self.state.max_steps = max_steps
|
||||
@@ -783,13 +794,13 @@ class Trainer:
|
||||
self.state.epoch = epoch + (step + 1) / steps_in_epoch
|
||||
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
||||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
|
||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
||||
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
if is_torch_tpu_available():
|
||||
@@ -823,7 +834,7 @@ class Trainer:
|
||||
|
||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||
|
||||
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
|
||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||
if self.control.should_log:
|
||||
logs: Dict[str, float] = {}
|
||||
tr_loss_scalar = tr_loss.item()
|
||||
@@ -842,6 +853,7 @@ class Trainer:
|
||||
if self.control.should_evaluate:
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
|
||||
if self.control.should_save:
|
||||
@@ -857,12 +869,15 @@ class Trainer:
|
||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
||||
# Save model checkpoint
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
|
||||
if self.hp_search_backend is not None and trial is not None:
|
||||
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
||||
checkpoint_folder += f"-run-{run_id}"
|
||||
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
||||
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
||||
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
|
||||
else:
|
||||
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
||||
|
||||
self.store_flos()
|
||||
self.store_flos()
|
||||
self.save_model(output_dir)
|
||||
|
||||
# Save optimizer and scheduler
|
||||
@@ -909,6 +924,7 @@ class Trainer:
|
||||
n_trials: int = 20,
|
||||
direction: str = "minimize",
|
||||
backend: Optional[Union["str", HPSearchBackend]] = None,
|
||||
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
|
||||
**kwargs
|
||||
) -> BestRun:
|
||||
"""
|
||||
@@ -966,13 +982,13 @@ class Trainer:
|
||||
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
||||
)
|
||||
self.hp_search_backend = backend
|
||||
|
||||
if self.model_init is None:
|
||||
raise RuntimeError(
|
||||
"To use hyperparameter search, you need to pass your model through a model_init function."
|
||||
)
|
||||
|
||||
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
||||
self.hp_name = hp_name
|
||||
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
||||
|
||||
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
|
||||
@@ -997,12 +1013,12 @@ class Trainer:
|
||||
FutureWarning,
|
||||
)
|
||||
return self._log(logs)
|
||||
|
||||
if self.state.epoch is not None:
|
||||
logs["epoch"] = self.state.epoch
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
logs["total_flos"] = self.state.total_flos
|
||||
|
||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
output = {**logs, **{"step": self.state.global_step}}
|
||||
self.state.log_history.append(output)
|
||||
|
||||
Reference in New Issue
Block a user