TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)
Improved TensorBoard and Wandb integration, as well as optuna and ray/tune support, with minor modifications to trainer core code.
This commit is contained in:
@@ -85,6 +85,17 @@ def is_ray_available():
|
|||||||
return _has_ray
|
return _has_ray
|
||||||
|
|
||||||
|
|
||||||
|
def hp_params(trial):
|
||||||
|
if is_optuna_available():
|
||||||
|
if isinstance(trial, optuna.Trial):
|
||||||
|
return trial.params
|
||||||
|
if is_ray_available():
|
||||||
|
if isinstance(trial, dict):
|
||||||
|
return trial
|
||||||
|
|
||||||
|
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
|
||||||
|
|
||||||
|
|
||||||
def default_hp_search_backend():
|
def default_hp_search_backend():
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
return "optuna"
|
return "optuna"
|
||||||
@@ -192,6 +203,18 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||||||
return best_run
|
return best_run
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite_logs(d):
|
||||||
|
new_d = {}
|
||||||
|
eval_prefix = "eval_"
|
||||||
|
eval_prefix_len = len(eval_prefix)
|
||||||
|
for k, v in d.items():
|
||||||
|
if k.startswith(eval_prefix):
|
||||||
|
new_d["eval/" + k[eval_prefix_len:]] = v
|
||||||
|
else:
|
||||||
|
new_d["train/" + k] = v
|
||||||
|
return new_d
|
||||||
|
|
||||||
|
|
||||||
class TensorBoardCallback(TrainerCallback):
|
class TensorBoardCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
|
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
|
||||||
@@ -208,17 +231,39 @@ class TensorBoardCallback(TrainerCallback):
|
|||||||
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
|
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
|
||||||
self.tb_writer = tb_writer
|
self.tb_writer = tb_writer
|
||||||
|
|
||||||
def on_init_end(self, args, state, control, **kwargs):
|
def _init_summary_writer(self, args, log_dir=None):
|
||||||
if self.tb_writer is None and state.is_world_process_zero:
|
log_dir = log_dir or args.logging_dir
|
||||||
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
|
self.tb_writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, **kwargs):
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return
|
||||||
|
|
||||||
|
log_dir = None
|
||||||
|
|
||||||
|
if state.is_hyper_param_search:
|
||||||
|
trial_name = state.trial_name
|
||||||
|
if trial_name is not None:
|
||||||
|
log_dir = os.path.join(args.logging_dir, trial_name)
|
||||||
|
|
||||||
|
self._init_summary_writer(args, log_dir)
|
||||||
|
|
||||||
if self.tb_writer is not None:
|
if self.tb_writer is not None:
|
||||||
self.tb_writer.add_text("args", args.to_json_string())
|
self.tb_writer.add_text("args", args.to_json_string())
|
||||||
|
if "model" in kwargs:
|
||||||
|
model = kwargs["model"]
|
||||||
|
if hasattr(model, "config") and model.config is not None:
|
||||||
|
model_config_json = model.config.to_json_string()
|
||||||
|
self.tb_writer.add_text("model_config", model_config_json)
|
||||||
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
||||||
|
|
||||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
if self.tb_writer is None:
|
||||||
|
self._init_summary_writer(args)
|
||||||
|
|
||||||
if self.tb_writer:
|
if self.tb_writer:
|
||||||
|
logs = rewrite_logs(logs)
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
if isinstance(v, (int, float)):
|
if isinstance(v, (int, float)):
|
||||||
self.tb_writer.add_scalar(k, v, state.global_step)
|
self.tb_writer.add_scalar(k, v, state.global_step)
|
||||||
@@ -249,7 +294,7 @@ class WandbCallback(TrainerCallback):
|
|||||||
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
|
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
def setup(self, args, state, model):
|
def setup(self, args, state, model, reinit, **kwargs):
|
||||||
"""
|
"""
|
||||||
Setup the optional Weights & Biases (`wandb`) integration.
|
Setup the optional Weights & Biases (`wandb`) integration.
|
||||||
|
|
||||||
@@ -271,21 +316,41 @@ class WandbCallback(TrainerCallback):
|
|||||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||||
)
|
)
|
||||||
combined_dict = {**args.to_sanitized_dict()}
|
combined_dict = {**args.to_sanitized_dict()}
|
||||||
if getattr(model, "config", None) is not None:
|
|
||||||
combined_dict = {**model.config.to_dict(), **combined_dict}
|
if hasattr(model, "config") and model.config is not None:
|
||||||
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
|
model_config = model.config.to_dict()
|
||||||
|
combined_dict = {**model_config, **combined_dict}
|
||||||
|
trial_name = state.trial_name
|
||||||
|
init_args = {}
|
||||||
|
if trial_name is not None:
|
||||||
|
run_name = trial_name
|
||||||
|
init_args["group"] = args.run_name
|
||||||
|
else:
|
||||||
|
run_name = args.run_name
|
||||||
|
|
||||||
|
wandb.init(
|
||||||
|
project=os.getenv("WANDB_PROJECT", "huggingface"),
|
||||||
|
config=combined_dict,
|
||||||
|
name=run_name,
|
||||||
|
reinit=reinit,
|
||||||
|
**init_args,
|
||||||
|
)
|
||||||
|
|
||||||
# keep track of model topology and gradients, unsupported on TPU
|
# keep track of model topology and gradients, unsupported on TPU
|
||||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||||
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
|
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
if not self._initialized:
|
hp_search = state.is_hyper_param_search
|
||||||
self.setup(args, state, model)
|
if not self._initialized or hp_search:
|
||||||
|
print(args.run_name)
|
||||||
|
self.setup(args, state, model, reinit=hp_search, **kwargs)
|
||||||
|
|
||||||
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.setup(args, state, model)
|
self.setup(args, state, model, reinit=False)
|
||||||
if state.is_world_process_zero:
|
if state.is_world_process_zero:
|
||||||
|
logs = rewrite_logs(logs)
|
||||||
wandb.log(logs, step=state.global_step)
|
wandb.log(logs, step=state.global_step)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from .file_utils import (
|
|||||||
_torch_available,
|
_torch_available,
|
||||||
_torch_tpu_available,
|
_torch_tpu_available,
|
||||||
)
|
)
|
||||||
|
from .integrations import _has_optuna, _has_ray
|
||||||
|
|
||||||
|
|
||||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||||
@@ -233,6 +234,32 @@ def require_faiss(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_optuna(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires optuna.
|
||||||
|
|
||||||
|
These tests are skipped when optuna isn't installed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not _has_optuna:
|
||||||
|
return unittest.skip("test requires optuna")(test_case)
|
||||||
|
else:
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_ray(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires Ray/tune.
|
||||||
|
|
||||||
|
These tests are skipped when Ray/tune isn't installed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not _has_ray:
|
||||||
|
return unittest.skip("test requires Ray/tune")(test_case)
|
||||||
|
else:
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
def get_tests_dir(append_path=None):
|
def get_tests_dir(append_path=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
|||||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||||
from .integrations import (
|
from .integrations import (
|
||||||
default_hp_search_backend,
|
default_hp_search_backend,
|
||||||
|
hp_params,
|
||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_available,
|
||||||
@@ -224,6 +225,7 @@ class Trainer:
|
|||||||
model is not None or model_init is not None
|
model is not None or model_init is not None
|
||||||
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
||||||
self.model_init = model_init
|
self.model_init = model_init
|
||||||
|
self.hp_name = None
|
||||||
if model is None and model_init is not None:
|
if model is None and model_init is not None:
|
||||||
model = self.call_model_init()
|
model = self.call_model_init()
|
||||||
self.model = model.to(args.device) if model is not None else None
|
self.model = model.to(args.device) if model is not None else None
|
||||||
@@ -508,8 +510,11 @@ class Trainer:
|
|||||||
|
|
||||||
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||||
""" HP search setup code """
|
""" HP search setup code """
|
||||||
|
self._trial = trial
|
||||||
|
|
||||||
if self.hp_search_backend is None or trial is None:
|
if self.hp_search_backend is None or trial is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
||||||
for key, value in params.items():
|
for key, value in params.items():
|
||||||
if not hasattr(self.args, key):
|
if not hasattr(self.args, key):
|
||||||
@@ -558,7 +563,10 @@ class Trainer:
|
|||||||
elif model_init_argcount == 1:
|
elif model_init_argcount == 1:
|
||||||
model = self.model_init(trial)
|
model = self.model_init(trial)
|
||||||
else:
|
else:
|
||||||
raise Exception("model_init should have 0 or 1 argument.")
|
raise RuntimeError("model_init should have 0 or 1 argument.")
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise RuntimeError("model_init should not return None.")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -617,6 +625,7 @@ class Trainer:
|
|||||||
|
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
|
self.state.is_hyper_param_search = trial is not None
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if (
|
if (
|
||||||
@@ -702,6 +711,8 @@ class Trainer:
|
|||||||
self.callback_handler.optimizer = self.optimizer
|
self.callback_handler.optimizer = self.optimizer
|
||||||
self.callback_handler.lr_scheduler = self.lr_scheduler
|
self.callback_handler.lr_scheduler = self.lr_scheduler
|
||||||
self.callback_handler.train_dataloader = train_dataloader
|
self.callback_handler.train_dataloader = train_dataloader
|
||||||
|
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
|
||||||
|
self.state.trial_params = hp_params(trial) if trial is not None else None
|
||||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||||
# to set this after the load.
|
# to set this after the load.
|
||||||
self.state.max_steps = max_steps
|
self.state.max_steps = max_steps
|
||||||
@@ -783,13 +794,13 @@ class Trainer:
|
|||||||
self.state.epoch = epoch + (step + 1) / steps_in_epoch
|
self.state.epoch = epoch + (step + 1) / steps_in_epoch
|
||||||
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
||||||
|
|
||||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||||
break
|
break
|
||||||
|
|
||||||
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
|
||||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
||||||
|
|
||||||
if self.args.tpu_metrics_debug or self.args.debug:
|
if self.args.tpu_metrics_debug or self.args.debug:
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
@@ -823,7 +834,7 @@ class Trainer:
|
|||||||
|
|
||||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||||
|
|
||||||
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
|
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||||
if self.control.should_log:
|
if self.control.should_log:
|
||||||
logs: Dict[str, float] = {}
|
logs: Dict[str, float] = {}
|
||||||
tr_loss_scalar = tr_loss.item()
|
tr_loss_scalar = tr_loss.item()
|
||||||
@@ -842,6 +853,7 @@ class Trainer:
|
|||||||
if self.control.should_evaluate:
|
if self.control.should_evaluate:
|
||||||
metrics = self.evaluate()
|
metrics = self.evaluate()
|
||||||
self._report_to_hp_search(trial, epoch, metrics)
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
|
|
||||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||||
|
|
||||||
if self.control.should_save:
|
if self.control.should_save:
|
||||||
@@ -857,9 +869,12 @@ class Trainer:
|
|||||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
assert model is self.model, f"Model {model} should be a reference to self.model"
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
|
|
||||||
if self.hp_search_backend is not None and trial is not None:
|
if self.hp_search_backend is not None and trial is not None:
|
||||||
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
||||||
checkpoint_folder += f"-run-{run_id}"
|
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
||||||
|
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
|
||||||
|
else:
|
||||||
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
||||||
|
|
||||||
self.store_flos()
|
self.store_flos()
|
||||||
@@ -909,6 +924,7 @@ class Trainer:
|
|||||||
n_trials: int = 20,
|
n_trials: int = 20,
|
||||||
direction: str = "minimize",
|
direction: str = "minimize",
|
||||||
backend: Optional[Union["str", HPSearchBackend]] = None,
|
backend: Optional[Union["str", HPSearchBackend]] = None,
|
||||||
|
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> BestRun:
|
) -> BestRun:
|
||||||
"""
|
"""
|
||||||
@@ -966,13 +982,13 @@ class Trainer:
|
|||||||
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
||||||
)
|
)
|
||||||
self.hp_search_backend = backend
|
self.hp_search_backend = backend
|
||||||
|
|
||||||
if self.model_init is None:
|
if self.model_init is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"To use hyperparameter search, you need to pass your model through a model_init function."
|
"To use hyperparameter search, you need to pass your model through a model_init function."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
||||||
|
self.hp_name = hp_name
|
||||||
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
||||||
|
|
||||||
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
|
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
|
||||||
@@ -997,12 +1013,12 @@ class Trainer:
|
|||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
return self._log(logs)
|
return self._log(logs)
|
||||||
|
|
||||||
if self.state.epoch is not None:
|
if self.state.epoch is not None:
|
||||||
logs["epoch"] = self.state.epoch
|
logs["epoch"] = self.state.epoch
|
||||||
if self._total_flos is not None:
|
if self._total_flos is not None:
|
||||||
self.store_flos()
|
self.store_flos()
|
||||||
logs["total_flos"] = self.state.total_flos
|
logs["total_flos"] = self.state.total_flos
|
||||||
|
|
||||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||||
output = {**logs, **{"step": self.state.global_step}}
|
output = {**logs, **{"step": self.state.global_step}}
|
||||||
self.state.log_history.append(output)
|
self.state.log_history.append(output)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ Callbacks to use with the Trainer class and customize the training loop.
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
@@ -66,6 +66,9 @@ class TrainerState:
|
|||||||
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not this process is the global main process (when training in a distributed fashion on
|
Whether or not this process is the global main process (when training in a distributed fashion on
|
||||||
several machines, this is only going to be :obj:`True` for one process).
|
several machines, this is only going to be :obj:`True` for one process).
|
||||||
|
is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search.
|
||||||
|
This will impact the way data will be logged in TensorBoard.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
epoch: Optional[float] = None
|
epoch: Optional[float] = None
|
||||||
@@ -78,6 +81,9 @@ class TrainerState:
|
|||||||
best_model_checkpoint: Optional[str] = None
|
best_model_checkpoint: Optional[str] = None
|
||||||
is_local_process_zero: bool = True
|
is_local_process_zero: bool = True
|
||||||
is_world_process_zero: bool = True
|
is_world_process_zero: bool = True
|
||||||
|
is_hyper_param_search: bool = False
|
||||||
|
trial_name: str = None
|
||||||
|
trial_params: Dict[str, Union[str, float, int, bool]] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.log_history is None:
|
if self.log_history is None:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
|
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -110,10 +111,15 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
|||||||
Return:
|
Return:
|
||||||
:obj:`float`: The objective to minimize or maximize
|
:obj:`float`: The objective to minimize or maximize
|
||||||
"""
|
"""
|
||||||
|
metrics = copy.deepcopy(metrics)
|
||||||
loss = metrics.pop("eval_loss", None)
|
loss = metrics.pop("eval_loss", None)
|
||||||
_ = metrics.pop("epoch", None)
|
_ = metrics.pop("epoch", None)
|
||||||
_ = metrics.pop("total_flos", None)
|
_ = metrics.pop("total_flos", None)
|
||||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
if len(metrics) != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
|
||||||
|
)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def default_hp_space_optuna(trial) -> Dict[str, float]:
|
def default_hp_space_optuna(trial) -> Dict[str, float]:
|
||||||
|
|||||||
148
src/transformers/utils/hp_naming.py
Normal file
148
src/transformers/utils/hp_naming.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import copy
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class TrialShortNamer:
|
||||||
|
PREFIX = "hp"
|
||||||
|
DEFAULTS = {}
|
||||||
|
NAMING_INFO = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_defaults(cls, prefix, defaults):
|
||||||
|
cls.PREFIX = prefix
|
||||||
|
cls.DEFAULTS = defaults
|
||||||
|
cls.build_naming_info()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def shortname_for_word(info, word):
|
||||||
|
if len(word) == 0:
|
||||||
|
return ""
|
||||||
|
short_word = None
|
||||||
|
if any(char.isdigit() for char in word):
|
||||||
|
raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
|
||||||
|
if word in info["short_word"]:
|
||||||
|
return info["short_word"][word]
|
||||||
|
for prefix_len in range(1, len(word) + 1):
|
||||||
|
prefix = word[:prefix_len]
|
||||||
|
if prefix in info["reverse_short_word"]:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
short_word = prefix
|
||||||
|
break
|
||||||
|
|
||||||
|
if short_word is None:
|
||||||
|
# Paranoid fallback
|
||||||
|
def int_to_alphabetic(integer):
|
||||||
|
s = ""
|
||||||
|
while integer != 0:
|
||||||
|
s = chr(ord("A") + integer % 10) + s
|
||||||
|
integer //= 10
|
||||||
|
return s
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
sword = word + "#" + int_to_alphabetic(i)
|
||||||
|
if sword in info["reverse_short_word"]:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
short_word = sword
|
||||||
|
break
|
||||||
|
|
||||||
|
info["short_word"][word] = short_word
|
||||||
|
info["reverse_short_word"][short_word] = word
|
||||||
|
return short_word
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def shortname_for_key(info, param_name):
|
||||||
|
words = param_name.split("_")
|
||||||
|
|
||||||
|
shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]
|
||||||
|
|
||||||
|
# We try to create a separatorless short name, but if there is a collision we have to fallback
|
||||||
|
# to a separated short name
|
||||||
|
separators = ["", "_"]
|
||||||
|
|
||||||
|
for separator in separators:
|
||||||
|
shortname = separator.join(shortname_parts)
|
||||||
|
if shortname not in info["reverse_short_param"]:
|
||||||
|
info["short_param"][param_name] = shortname
|
||||||
|
info["reverse_short_param"][shortname] = param_name
|
||||||
|
return shortname
|
||||||
|
|
||||||
|
return param_name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_new_param_name(info, param_name):
|
||||||
|
short_name = TrialShortNamer.shortname_for_key(info, param_name)
|
||||||
|
info["short_param"][param_name] = short_name
|
||||||
|
info["reverse_short_param"][short_name] = param_name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_naming_info(cls):
|
||||||
|
if cls.NAMING_INFO is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
info = dict(
|
||||||
|
short_word={},
|
||||||
|
reverse_short_word={},
|
||||||
|
short_param={},
|
||||||
|
reverse_short_param={},
|
||||||
|
)
|
||||||
|
|
||||||
|
field_keys = list(cls.DEFAULTS.keys())
|
||||||
|
|
||||||
|
for k in field_keys:
|
||||||
|
cls.add_new_param_name(info, k)
|
||||||
|
|
||||||
|
cls.NAMING_INFO = info
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def shortname(cls, params):
|
||||||
|
cls.build_naming_info()
|
||||||
|
assert cls.PREFIX is not None
|
||||||
|
name = [copy.copy(cls.PREFIX)]
|
||||||
|
|
||||||
|
for k, v in params.items():
|
||||||
|
if k not in cls.DEFAULTS:
|
||||||
|
raise Exception(f"You should provide a default value for the param name {k} with value {v}")
|
||||||
|
if v == cls.DEFAULTS[k]:
|
||||||
|
# The default value is not added to the name
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = cls.NAMING_INFO["short_param"][k]
|
||||||
|
|
||||||
|
if isinstance(v, bool):
|
||||||
|
v = 1 if v else 0
|
||||||
|
|
||||||
|
sep = "" if isinstance(v, (int, float)) else "-"
|
||||||
|
e = f"{key}{sep}{v}"
|
||||||
|
name.append(e)
|
||||||
|
|
||||||
|
return "_".join(name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_repr(cls, repr):
|
||||||
|
repr = repr[len(cls.PREFIX) + 1 :]
|
||||||
|
if repr == "":
|
||||||
|
values = []
|
||||||
|
else:
|
||||||
|
values = repr.split("_")
|
||||||
|
|
||||||
|
parameters = {}
|
||||||
|
|
||||||
|
for value in values:
|
||||||
|
if "-" in value:
|
||||||
|
p_k, p_v = value.split("-")
|
||||||
|
else:
|
||||||
|
p_k = re.sub("[0-9.]", "", value)
|
||||||
|
p_v = float(re.sub("[^0-9.]", "", value))
|
||||||
|
|
||||||
|
key = cls.NAMING_INFO["reverse_short_param"][p_k]
|
||||||
|
|
||||||
|
parameters[key] = p_v
|
||||||
|
|
||||||
|
for k in cls.DEFAULTS:
|
||||||
|
if k not in parameters:
|
||||||
|
parameters[k] = cls.DEFAULTS[k]
|
||||||
|
|
||||||
|
return parameters
|
||||||
@@ -21,9 +21,17 @@ import unittest
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
|
from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
||||||
from transformers.file_utils import WEIGHTS_NAME
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
|
from transformers.testing_utils import (
|
||||||
|
get_tests_dir,
|
||||||
|
require_optuna,
|
||||||
|
require_sentencepiece,
|
||||||
|
require_tokenizers,
|
||||||
|
require_torch,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
from transformers.utils.hp_naming import TrialShortNamer
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -142,6 +150,7 @@ if is_torch_available():
|
|||||||
data_collator = kwargs.pop("data_collator", None)
|
data_collator = kwargs.pop("data_collator", None)
|
||||||
optimizers = kwargs.pop("optimizers", (None, None))
|
optimizers = kwargs.pop("optimizers", (None, None))
|
||||||
output_dir = kwargs.pop("output_dir", "./regression")
|
output_dir = kwargs.pop("output_dir", "./regression")
|
||||||
|
model_init = kwargs.pop("model_init", None)
|
||||||
args = TrainingArguments(output_dir, **kwargs)
|
args = TrainingArguments(output_dir, **kwargs)
|
||||||
return Trainer(
|
return Trainer(
|
||||||
model,
|
model,
|
||||||
@@ -151,6 +160,7 @@ if is_torch_available():
|
|||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
compute_metrics=compute_metrics,
|
compute_metrics=compute_metrics,
|
||||||
optimizers=optimizers,
|
optimizers=optimizers,
|
||||||
|
model_init=model_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -617,3 +627,46 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# with enforced DataParallel
|
# with enforced DataParallel
|
||||||
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_optuna
|
||||||
|
class TrainerHyperParameterIntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
args = TrainingArguments(".")
|
||||||
|
self.n_epochs = args.num_train_epochs
|
||||||
|
self.batch_size = args.train_batch_size
|
||||||
|
|
||||||
|
def test_hyperparameter_search(self):
|
||||||
|
class MyTrialShortNamer(TrialShortNamer):
|
||||||
|
DEFAULTS = {"a": 0, "b": 0}
|
||||||
|
|
||||||
|
def hp_space(trial):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def model_init(trial):
|
||||||
|
if trial is not None:
|
||||||
|
a = trial.suggest_int("a", -4, 4)
|
||||||
|
b = trial.suggest_int("b", -4, 4)
|
||||||
|
else:
|
||||||
|
a = 0
|
||||||
|
b = 0
|
||||||
|
config = RegressionModelConfig(a=a, b=b, double_output=False)
|
||||||
|
|
||||||
|
return RegressionPreTrainedModel(config)
|
||||||
|
|
||||||
|
def hp_name(trial):
|
||||||
|
return MyTrialShortNamer.shortname(trial.params)
|
||||||
|
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
learning_rate=0.1,
|
||||||
|
logging_steps=1,
|
||||||
|
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||||
|
num_train_epochs=4,
|
||||||
|
disable_tqdm=True,
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
logging_dir="runs",
|
||||||
|
run_name="test",
|
||||||
|
model_init=model_init,
|
||||||
|
)
|
||||||
|
trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
|
||||||
|
|||||||
Reference in New Issue
Block a user