TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)
Improved TensorBoard and Wandb integration, as well as optuna and ray/tune support, with minor modifications to trainer core code.
This commit is contained in:
@@ -85,6 +85,17 @@ def is_ray_available():
|
||||
return _has_ray
|
||||
|
||||
|
||||
def hp_params(trial):
|
||||
if is_optuna_available():
|
||||
if isinstance(trial, optuna.Trial):
|
||||
return trial.params
|
||||
if is_ray_available():
|
||||
if isinstance(trial, dict):
|
||||
return trial
|
||||
|
||||
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
|
||||
|
||||
|
||||
def default_hp_search_backend():
|
||||
if is_optuna_available():
|
||||
return "optuna"
|
||||
@@ -192,6 +203,18 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
return best_run
|
||||
|
||||
|
||||
def rewrite_logs(d):
|
||||
new_d = {}
|
||||
eval_prefix = "eval_"
|
||||
eval_prefix_len = len(eval_prefix)
|
||||
for k, v in d.items():
|
||||
if k.startswith(eval_prefix):
|
||||
new_d["eval/" + k[eval_prefix_len:]] = v
|
||||
else:
|
||||
new_d["train/" + k] = v
|
||||
return new_d
|
||||
|
||||
|
||||
class TensorBoardCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
|
||||
@@ -208,17 +231,39 @@ class TensorBoardCallback(TrainerCallback):
|
||||
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
|
||||
self.tb_writer = tb_writer
|
||||
|
||||
def on_init_end(self, args, state, control, **kwargs):
|
||||
if self.tb_writer is None and state.is_world_process_zero:
|
||||
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
|
||||
def _init_summary_writer(self, args, log_dir=None):
|
||||
log_dir = log_dir or args.logging_dir
|
||||
self.tb_writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
log_dir = None
|
||||
|
||||
if state.is_hyper_param_search:
|
||||
trial_name = state.trial_name
|
||||
if trial_name is not None:
|
||||
log_dir = os.path.join(args.logging_dir, trial_name)
|
||||
|
||||
self._init_summary_writer(args, log_dir)
|
||||
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_text("args", args.to_json_string())
|
||||
if "model" in kwargs:
|
||||
model = kwargs["model"]
|
||||
if hasattr(model, "config") and model.config is not None:
|
||||
model_config_json = model.config.to_json_string()
|
||||
self.tb_writer.add_text("model_config", model_config_json)
|
||||
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.tb_writer is None:
|
||||
self._init_summary_writer(args)
|
||||
|
||||
if self.tb_writer:
|
||||
logs = rewrite_logs(logs)
|
||||
for k, v in logs.items():
|
||||
if isinstance(v, (int, float)):
|
||||
self.tb_writer.add_scalar(k, v, state.global_step)
|
||||
@@ -249,7 +294,7 @@ class WandbCallback(TrainerCallback):
|
||||
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
|
||||
self._initialized = False
|
||||
|
||||
def setup(self, args, state, model):
|
||||
def setup(self, args, state, model, reinit, **kwargs):
|
||||
"""
|
||||
Setup the optional Weights & Biases (`wandb`) integration.
|
||||
|
||||
@@ -271,21 +316,41 @@ class WandbCallback(TrainerCallback):
|
||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||
)
|
||||
combined_dict = {**args.to_sanitized_dict()}
|
||||
if getattr(model, "config", None) is not None:
|
||||
combined_dict = {**model.config.to_dict(), **combined_dict}
|
||||
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
|
||||
|
||||
if hasattr(model, "config") and model.config is not None:
|
||||
model_config = model.config.to_dict()
|
||||
combined_dict = {**model_config, **combined_dict}
|
||||
trial_name = state.trial_name
|
||||
init_args = {}
|
||||
if trial_name is not None:
|
||||
run_name = trial_name
|
||||
init_args["group"] = args.run_name
|
||||
else:
|
||||
run_name = args.run_name
|
||||
|
||||
wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT", "huggingface"),
|
||||
config=combined_dict,
|
||||
name=run_name,
|
||||
reinit=reinit,
|
||||
**init_args,
|
||||
)
|
||||
|
||||
# keep track of model topology and gradients, unsupported on TPU
|
||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model)
|
||||
hp_search = state.is_hyper_param_search
|
||||
if not self._initialized or hp_search:
|
||||
print(args.run_name)
|
||||
self.setup(args, state, model, reinit=hp_search, **kwargs)
|
||||
|
||||
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model)
|
||||
self.setup(args, state, model, reinit=False)
|
||||
if state.is_world_process_zero:
|
||||
logs = rewrite_logs(logs)
|
||||
wandb.log(logs, step=state.global_step)
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from .file_utils import (
|
||||
_torch_available,
|
||||
_torch_tpu_available,
|
||||
)
|
||||
from .integrations import _has_optuna, _has_ray
|
||||
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
@@ -233,6 +234,32 @@ def require_faiss(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_optuna(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires optuna.
|
||||
|
||||
These tests are skipped when optuna isn't installed.
|
||||
|
||||
"""
|
||||
if not _has_optuna:
|
||||
return unittest.skip("test requires optuna")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_ray(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Ray/tune.
|
||||
|
||||
These tests are skipped when Ray/tune isn't installed.
|
||||
|
||||
"""
|
||||
if not _has_ray:
|
||||
return unittest.skip("test requires Ray/tune")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def get_tests_dir(append_path=None):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -39,6 +39,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||
from .integrations import (
|
||||
default_hp_search_backend,
|
||||
hp_params,
|
||||
is_comet_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
@@ -224,6 +225,7 @@ class Trainer:
|
||||
model is not None or model_init is not None
|
||||
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
||||
self.model_init = model_init
|
||||
self.hp_name = None
|
||||
if model is None and model_init is not None:
|
||||
model = self.call_model_init()
|
||||
self.model = model.to(args.device) if model is not None else None
|
||||
@@ -508,8 +510,11 @@ class Trainer:
|
||||
|
||||
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||
""" HP search setup code """
|
||||
self._trial = trial
|
||||
|
||||
if self.hp_search_backend is None or trial is None:
|
||||
return
|
||||
|
||||
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
||||
for key, value in params.items():
|
||||
if not hasattr(self.args, key):
|
||||
@@ -558,7 +563,10 @@ class Trainer:
|
||||
elif model_init_argcount == 1:
|
||||
model = self.model_init(trial)
|
||||
else:
|
||||
raise Exception("model_init should have 0 or 1 argument.")
|
||||
raise RuntimeError("model_init should have 0 or 1 argument.")
|
||||
|
||||
if model is None:
|
||||
raise RuntimeError("model_init should not return None.")
|
||||
|
||||
return model
|
||||
|
||||
@@ -617,6 +625,7 @@ class Trainer:
|
||||
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
self.state = TrainerState()
|
||||
self.state.is_hyper_param_search = trial is not None
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
if (
|
||||
@@ -702,6 +711,8 @@ class Trainer:
|
||||
self.callback_handler.optimizer = self.optimizer
|
||||
self.callback_handler.lr_scheduler = self.lr_scheduler
|
||||
self.callback_handler.train_dataloader = train_dataloader
|
||||
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
|
||||
self.state.trial_params = hp_params(trial) if trial is not None else None
|
||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||
# to set this after the load.
|
||||
self.state.max_steps = max_steps
|
||||
@@ -783,13 +794,13 @@ class Trainer:
|
||||
self.state.epoch = epoch + (step + 1) / steps_in_epoch
|
||||
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
||||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
|
||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
||||
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
if is_torch_tpu_available():
|
||||
@@ -823,7 +834,7 @@ class Trainer:
|
||||
|
||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||
|
||||
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
|
||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||
if self.control.should_log:
|
||||
logs: Dict[str, float] = {}
|
||||
tr_loss_scalar = tr_loss.item()
|
||||
@@ -842,6 +853,7 @@ class Trainer:
|
||||
if self.control.should_evaluate:
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
|
||||
if self.control.should_save:
|
||||
@@ -857,12 +869,15 @@ class Trainer:
|
||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
||||
# Save model checkpoint
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
|
||||
if self.hp_search_backend is not None and trial is not None:
|
||||
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
||||
checkpoint_folder += f"-run-{run_id}"
|
||||
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
||||
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
||||
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
|
||||
else:
|
||||
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
||||
|
||||
self.store_flos()
|
||||
self.store_flos()
|
||||
self.save_model(output_dir)
|
||||
|
||||
# Save optimizer and scheduler
|
||||
@@ -909,6 +924,7 @@ class Trainer:
|
||||
n_trials: int = 20,
|
||||
direction: str = "minimize",
|
||||
backend: Optional[Union["str", HPSearchBackend]] = None,
|
||||
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
|
||||
**kwargs
|
||||
) -> BestRun:
|
||||
"""
|
||||
@@ -966,13 +982,13 @@ class Trainer:
|
||||
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
||||
)
|
||||
self.hp_search_backend = backend
|
||||
|
||||
if self.model_init is None:
|
||||
raise RuntimeError(
|
||||
"To use hyperparameter search, you need to pass your model through a model_init function."
|
||||
)
|
||||
|
||||
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
||||
self.hp_name = hp_name
|
||||
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
||||
|
||||
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
|
||||
@@ -997,12 +1013,12 @@ class Trainer:
|
||||
FutureWarning,
|
||||
)
|
||||
return self._log(logs)
|
||||
|
||||
if self.state.epoch is not None:
|
||||
logs["epoch"] = self.state.epoch
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
logs["total_flos"] = self.state.total_flos
|
||||
|
||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
output = {**logs, **{"step": self.state.global_step}}
|
||||
self.state.log_history.append(output)
|
||||
|
||||
@@ -19,7 +19,7 @@ Callbacks to use with the Trainer class and customize the training loop.
|
||||
import dataclasses
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
@@ -66,6 +66,9 @@ class TrainerState:
|
||||
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not this process is the global main process (when training in a distributed fashion on
|
||||
several machines, this is only going to be :obj:`True` for one process).
|
||||
is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search.
|
||||
This will impact the way data will be logged in TensorBoard.
|
||||
"""
|
||||
|
||||
epoch: Optional[float] = None
|
||||
@@ -78,6 +81,9 @@ class TrainerState:
|
||||
best_model_checkpoint: Optional[str] = None
|
||||
is_local_process_zero: bool = True
|
||||
is_world_process_zero: bool = True
|
||||
is_hyper_param_search: bool = False
|
||||
trial_name: str = None
|
||||
trial_params: Dict[str, Union[str, float, int, bool]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.log_history is None:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import random
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
@@ -110,10 +111,15 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
||||
Return:
|
||||
:obj:`float`: The objective to minimize or maximize
|
||||
"""
|
||||
metrics = copy.deepcopy(metrics)
|
||||
loss = metrics.pop("eval_loss", None)
|
||||
_ = metrics.pop("epoch", None)
|
||||
_ = metrics.pop("total_flos", None)
|
||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||
if len(metrics) != 0:
|
||||
raise RuntimeError(
|
||||
"Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
|
||||
)
|
||||
return loss
|
||||
|
||||
|
||||
def default_hp_space_optuna(trial) -> Dict[str, float]:
|
||||
|
||||
148
src/transformers/utils/hp_naming.py
Normal file
148
src/transformers/utils/hp_naming.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import copy
|
||||
import re
|
||||
|
||||
|
||||
class TrialShortNamer:
|
||||
PREFIX = "hp"
|
||||
DEFAULTS = {}
|
||||
NAMING_INFO = None
|
||||
|
||||
@classmethod
|
||||
def set_defaults(cls, prefix, defaults):
|
||||
cls.PREFIX = prefix
|
||||
cls.DEFAULTS = defaults
|
||||
cls.build_naming_info()
|
||||
|
||||
@staticmethod
|
||||
def shortname_for_word(info, word):
|
||||
if len(word) == 0:
|
||||
return ""
|
||||
short_word = None
|
||||
if any(char.isdigit() for char in word):
|
||||
raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
|
||||
if word in info["short_word"]:
|
||||
return info["short_word"][word]
|
||||
for prefix_len in range(1, len(word) + 1):
|
||||
prefix = word[:prefix_len]
|
||||
if prefix in info["reverse_short_word"]:
|
||||
continue
|
||||
else:
|
||||
short_word = prefix
|
||||
break
|
||||
|
||||
if short_word is None:
|
||||
# Paranoid fallback
|
||||
def int_to_alphabetic(integer):
|
||||
s = ""
|
||||
while integer != 0:
|
||||
s = chr(ord("A") + integer % 10) + s
|
||||
integer //= 10
|
||||
return s
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
sword = word + "#" + int_to_alphabetic(i)
|
||||
if sword in info["reverse_short_word"]:
|
||||
continue
|
||||
else:
|
||||
short_word = sword
|
||||
break
|
||||
|
||||
info["short_word"][word] = short_word
|
||||
info["reverse_short_word"][short_word] = word
|
||||
return short_word
|
||||
|
||||
@staticmethod
|
||||
def shortname_for_key(info, param_name):
|
||||
words = param_name.split("_")
|
||||
|
||||
shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]
|
||||
|
||||
# We try to create a separatorless short name, but if there is a collision we have to fallback
|
||||
# to a separated short name
|
||||
separators = ["", "_"]
|
||||
|
||||
for separator in separators:
|
||||
shortname = separator.join(shortname_parts)
|
||||
if shortname not in info["reverse_short_param"]:
|
||||
info["short_param"][param_name] = shortname
|
||||
info["reverse_short_param"][shortname] = param_name
|
||||
return shortname
|
||||
|
||||
return param_name
|
||||
|
||||
@staticmethod
|
||||
def add_new_param_name(info, param_name):
|
||||
short_name = TrialShortNamer.shortname_for_key(info, param_name)
|
||||
info["short_param"][param_name] = short_name
|
||||
info["reverse_short_param"][short_name] = param_name
|
||||
|
||||
@classmethod
|
||||
def build_naming_info(cls):
|
||||
if cls.NAMING_INFO is not None:
|
||||
return
|
||||
|
||||
info = dict(
|
||||
short_word={},
|
||||
reverse_short_word={},
|
||||
short_param={},
|
||||
reverse_short_param={},
|
||||
)
|
||||
|
||||
field_keys = list(cls.DEFAULTS.keys())
|
||||
|
||||
for k in field_keys:
|
||||
cls.add_new_param_name(info, k)
|
||||
|
||||
cls.NAMING_INFO = info
|
||||
|
||||
@classmethod
|
||||
def shortname(cls, params):
|
||||
cls.build_naming_info()
|
||||
assert cls.PREFIX is not None
|
||||
name = [copy.copy(cls.PREFIX)]
|
||||
|
||||
for k, v in params.items():
|
||||
if k not in cls.DEFAULTS:
|
||||
raise Exception(f"You should provide a default value for the param name {k} with value {v}")
|
||||
if v == cls.DEFAULTS[k]:
|
||||
# The default value is not added to the name
|
||||
continue
|
||||
|
||||
key = cls.NAMING_INFO["short_param"][k]
|
||||
|
||||
if isinstance(v, bool):
|
||||
v = 1 if v else 0
|
||||
|
||||
sep = "" if isinstance(v, (int, float)) else "-"
|
||||
e = f"{key}{sep}{v}"
|
||||
name.append(e)
|
||||
|
||||
return "_".join(name)
|
||||
|
||||
@classmethod
|
||||
def parse_repr(cls, repr):
|
||||
repr = repr[len(cls.PREFIX) + 1 :]
|
||||
if repr == "":
|
||||
values = []
|
||||
else:
|
||||
values = repr.split("_")
|
||||
|
||||
parameters = {}
|
||||
|
||||
for value in values:
|
||||
if "-" in value:
|
||||
p_k, p_v = value.split("-")
|
||||
else:
|
||||
p_k = re.sub("[0-9.]", "", value)
|
||||
p_v = float(re.sub("[^0-9.]", "", value))
|
||||
|
||||
key = cls.NAMING_INFO["reverse_short_param"][p_k]
|
||||
|
||||
parameters[key] = p_v
|
||||
|
||||
for k in cls.DEFAULTS:
|
||||
if k not in parameters:
|
||||
parameters[k] = cls.DEFAULTS[k]
|
||||
|
||||
return parameters
|
||||
@@ -21,9 +21,17 @@ import unittest
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
require_optuna,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils.hp_naming import TrialShortNamer
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -142,6 +150,7 @@ if is_torch_available():
|
||||
data_collator = kwargs.pop("data_collator", None)
|
||||
optimizers = kwargs.pop("optimizers", (None, None))
|
||||
output_dir = kwargs.pop("output_dir", "./regression")
|
||||
model_init = kwargs.pop("model_init", None)
|
||||
args = TrainingArguments(output_dir, **kwargs)
|
||||
return Trainer(
|
||||
model,
|
||||
@@ -151,6 +160,7 @@ if is_torch_available():
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
optimizers=optimizers,
|
||||
model_init=model_init,
|
||||
)
|
||||
|
||||
|
||||
@@ -617,3 +627,46 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
|
||||
# with enforced DataParallel
|
||||
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_optuna
|
||||
class TrainerHyperParameterIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
args = TrainingArguments(".")
|
||||
self.n_epochs = args.num_train_epochs
|
||||
self.batch_size = args.train_batch_size
|
||||
|
||||
def test_hyperparameter_search(self):
|
||||
class MyTrialShortNamer(TrialShortNamer):
|
||||
DEFAULTS = {"a": 0, "b": 0}
|
||||
|
||||
def hp_space(trial):
|
||||
return {}
|
||||
|
||||
def model_init(trial):
|
||||
if trial is not None:
|
||||
a = trial.suggest_int("a", -4, 4)
|
||||
b = trial.suggest_int("b", -4, 4)
|
||||
else:
|
||||
a = 0
|
||||
b = 0
|
||||
config = RegressionModelConfig(a=a, b=b, double_output=False)
|
||||
|
||||
return RegressionPreTrainedModel(config)
|
||||
|
||||
def hp_name(trial):
|
||||
return MyTrialShortNamer.shortname(trial.params)
|
||||
|
||||
trainer = get_regression_trainer(
|
||||
learning_rate=0.1,
|
||||
logging_steps=1,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
num_train_epochs=4,
|
||||
disable_tqdm=True,
|
||||
load_best_model_at_end=True,
|
||||
logging_dir="runs",
|
||||
run_name="test",
|
||||
model_init=model_init,
|
||||
)
|
||||
trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
|
||||
|
||||
Reference in New Issue
Block a user