Add hyperparameter search to Trainer (#6576)
* Add optuna hyperparameter search to Trainer * @julien-c suggestions Co-authored-by: Julien Chaumond <chaumond@gmail.com> * Make compute_objective an arg function * Formatting * Rework to make it easier to add ray * Formatting * Initial support for Ray * Formatting * Polish and finalize * Add trial id to checkpoint with Ray * Smaller default * Use GPU in ray if available * Formatting * Fix test * Update install instruction Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Address review comments * Formatting post-merge Co-authored-by: Julien Chaumond <chaumond@gmail.com> Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -92,7 +92,13 @@ from .file_utils import (
|
|||||||
from .hf_argparser import HfArgumentParser
|
from .hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
# Integrations
|
# Integrations
|
||||||
from .integrations import is_comet_available, is_tensorboard_available, is_wandb_available
|
from .integrations import (
|
||||||
|
is_comet_available,
|
||||||
|
is_optuna_available,
|
||||||
|
is_ray_available,
|
||||||
|
is_tensorboard_available,
|
||||||
|
is_wandb_available,
|
||||||
|
)
|
||||||
|
|
||||||
# Model Cards
|
# Model Cards
|
||||||
from .modelcard import ModelCard
|
from .modelcard import ModelCard
|
||||||
|
|||||||
@@ -35,6 +35,20 @@ except ImportError:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
_has_tensorboard = False
|
_has_tensorboard = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import optuna # noqa: F401
|
||||||
|
|
||||||
|
_has_optuna = True
|
||||||
|
except (ImportError):
|
||||||
|
_has_optuna = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ray # noqa: F401
|
||||||
|
|
||||||
|
_has_ray = True
|
||||||
|
except (ImportError):
|
||||||
|
_has_ray = False
|
||||||
|
|
||||||
|
|
||||||
def is_wandb_available():
|
def is_wandb_available():
|
||||||
return _has_wandb
|
return _has_wandb
|
||||||
@@ -46,3 +60,18 @@ def is_comet_available():
|
|||||||
|
|
||||||
def is_tensorboard_available():
|
def is_tensorboard_available():
|
||||||
return _has_tensorboard
|
return _has_tensorboard
|
||||||
|
|
||||||
|
|
||||||
|
def is_optuna_available():
|
||||||
|
return _has_optuna
|
||||||
|
|
||||||
|
|
||||||
|
def is_ray_available():
|
||||||
|
return _has_ray
|
||||||
|
|
||||||
|
|
||||||
|
def default_hp_search_backend():
|
||||||
|
if is_optuna_available():
|
||||||
|
return "optuna"
|
||||||
|
elif is_ray_available():
|
||||||
|
return "ray"
|
||||||
|
|||||||
@@ -21,10 +21,27 @@ from tqdm.auto import tqdm, trange
|
|||||||
|
|
||||||
from .data.data_collator import DataCollator, default_data_collator
|
from .data.data_collator import DataCollator, default_data_collator
|
||||||
from .file_utils import is_nlp_available, is_torch_tpu_available
|
from .file_utils import is_nlp_available, is_torch_tpu_available
|
||||||
from .integrations import is_comet_available, is_tensorboard_available, is_wandb_available
|
from .integrations import (
|
||||||
|
default_hp_search_backend,
|
||||||
|
is_comet_available,
|
||||||
|
is_optuna_available,
|
||||||
|
is_ray_available,
|
||||||
|
is_tensorboard_available,
|
||||||
|
is_wandb_available,
|
||||||
|
)
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, set_seed
|
from .trainer_utils import (
|
||||||
|
PREFIX_CHECKPOINT_DIR,
|
||||||
|
BestRun,
|
||||||
|
EvalPrediction,
|
||||||
|
HPSearchBackend,
|
||||||
|
PredictionOutput,
|
||||||
|
TrainOutput,
|
||||||
|
default_compute_objective,
|
||||||
|
default_hp_space,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -62,6 +79,12 @@ if is_wandb_available():
|
|||||||
if is_comet_available():
|
if is_comet_available():
|
||||||
import comet_ml
|
import comet_ml
|
||||||
|
|
||||||
|
if is_optuna_available():
|
||||||
|
import optuna
|
||||||
|
|
||||||
|
if is_ray_available():
|
||||||
|
from ray import tune
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -140,10 +163,11 @@ class Trainer:
|
|||||||
optimized for 🤗 Transformers.
|
optimized for 🤗 Transformers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (:class:`~transformers.PreTrainedModel`):
|
model (:class:`~transformers.PreTrainedModel`, `optional`):
|
||||||
The model to train, evaluate or use for predictions.
|
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
|
||||||
args (:class:`~transformers.TrainingArguments`):
|
args (:class:`~transformers.TrainingArguments`, `optional`):
|
||||||
The arguments to tweak for training.
|
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
|
||||||
|
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
|
||||||
data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
|
data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
|
||||||
The function to use to form a batch from a list of elements of :obj:`train_dataset` or
|
The function to use to form a batch from a list of elements of :obj:`train_dataset` or
|
||||||
:obj:`eval_dataset`.
|
:obj:`eval_dataset`.
|
||||||
@@ -153,6 +177,9 @@ class Trainer:
|
|||||||
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
||||||
The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
|
The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
|
||||||
``model.forward()`` method are automatically removed.
|
``model.forward()`` method are automatically removed.
|
||||||
|
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
|
||||||
|
A function that instantiates the model to be used. If provided, each call to
|
||||||
|
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
|
||||||
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
||||||
The function that will be used to compute metrics at evaluation. Must take a
|
The function that will be used to compute metrics at evaluation. Must take a
|
||||||
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
||||||
@@ -168,21 +195,31 @@ class Trainer:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel = None,
|
||||||
args: TrainingArguments,
|
args: TrainingArguments = None,
|
||||||
data_collator: Optional[DataCollator] = None,
|
data_collator: Optional[DataCollator] = None,
|
||||||
train_dataset: Optional[Dataset] = None,
|
train_dataset: Optional[Dataset] = None,
|
||||||
eval_dataset: Optional[Dataset] = None,
|
eval_dataset: Optional[Dataset] = None,
|
||||||
|
model_init: Callable[[], PreTrainedModel] = None,
|
||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||||
tb_writer: Optional["SummaryWriter"] = None,
|
tb_writer: Optional["SummaryWriter"] = None,
|
||||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model = model.to(args.device)
|
assert (
|
||||||
|
model is not None or model_init is not None
|
||||||
|
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
||||||
|
if model is None and model_init is not None:
|
||||||
|
model = model_init()
|
||||||
|
self.model = model.to(args.device) if model is not None else None
|
||||||
|
if args is None:
|
||||||
|
logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
|
||||||
|
args = TrainingArguments("tmp_trainer")
|
||||||
self.args = args
|
self.args = args
|
||||||
self.data_collator = data_collator if data_collator is not None else default_data_collator
|
self.data_collator = data_collator if data_collator is not None else default_data_collator
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
|
self.model_init = model_init
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
self.optimizer, self.lr_scheduler = optimizers
|
self.optimizer, self.lr_scheduler = optimizers
|
||||||
self.tb_writer = tb_writer
|
self.tb_writer = tb_writer
|
||||||
@@ -242,6 +279,7 @@ class Trainer:
|
|||||||
self.epoch = None
|
self.epoch = None
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.args.fp16 and _use_native_amp:
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
|
self.hp_search_backend = None
|
||||||
|
|
||||||
def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
|
def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
|
||||||
if not self.args.remove_unused_columns:
|
if not self.args.remove_unused_columns:
|
||||||
@@ -462,7 +500,38 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
return len(dataloader.dataset)
|
return len(dataloader.dataset)
|
||||||
|
|
||||||
def train(self, model_path: Optional[str] = None):
|
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||||
|
""" HP search setup code """
|
||||||
|
if self.hp_search_backend is None or trial is None:
|
||||||
|
return
|
||||||
|
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
||||||
|
for key, value in params.items():
|
||||||
|
if not hasattr(self.args, key):
|
||||||
|
raise AttributeError(
|
||||||
|
f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
|
||||||
|
)
|
||||||
|
old_attr = getattr(self.args, key, None)
|
||||||
|
# Casting value to the proper type
|
||||||
|
if old_attr is not None:
|
||||||
|
value = type(old_attr)(value)
|
||||||
|
setattr(self.args, key, value)
|
||||||
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||||
|
logger.info("Trial:", trial.params)
|
||||||
|
|
||||||
|
def _report_to_hp_search(
|
||||||
|
self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
|
||||||
|
):
|
||||||
|
if self.hp_search_backend is None or trial is None:
|
||||||
|
return
|
||||||
|
self.objective = self.compute_objective(metrics)
|
||||||
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||||
|
trial.report(self.objective, epoch)
|
||||||
|
if trial.should_prune():
|
||||||
|
raise optuna.TrialPruned()
|
||||||
|
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||||
|
tune.report(objective=self.objective, **metrics)
|
||||||
|
|
||||||
|
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
||||||
"""
|
"""
|
||||||
Main training entry point.
|
Main training entry point.
|
||||||
|
|
||||||
@@ -470,7 +539,17 @@ class Trainer:
|
|||||||
model_path (:obj:`str`, `optional`):
|
model_path (:obj:`str`, `optional`):
|
||||||
Local path to the model if the model to train has been instantiated from a local path. If present,
|
Local path to the model if the model to train has been instantiated from a local path. If present,
|
||||||
training will resume from the optimizer/scheduler states loaded here.
|
training will resume from the optimizer/scheduler states loaded here.
|
||||||
|
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
||||||
|
The trial run or the hyperparameter dictionary for hyperparameter search.
|
||||||
"""
|
"""
|
||||||
|
# Model re-init
|
||||||
|
if self.model_init is not None:
|
||||||
|
model = self.model_init()
|
||||||
|
self.model = model.to(self.args.device)
|
||||||
|
|
||||||
|
self._hp_search_setup(trial)
|
||||||
|
|
||||||
|
# Data loader and number of training steps
|
||||||
train_dataloader = self.get_train_dataloader()
|
train_dataloader = self.get_train_dataloader()
|
||||||
if self.args.max_steps > 0:
|
if self.args.max_steps > 0:
|
||||||
t_total = self.args.max_steps
|
t_total = self.args.max_steps
|
||||||
@@ -561,9 +640,8 @@ class Trainer:
|
|||||||
tr_loss = 0.0
|
tr_loss = 0.0
|
||||||
logging_loss = 0.0
|
logging_loss = 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(
|
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
||||||
epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=not self.is_local_process_zero()
|
train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
|
||||||
)
|
|
||||||
for epoch in train_iterator:
|
for epoch in train_iterator:
|
||||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
train_dataloader.sampler.set_epoch(epoch)
|
train_dataloader.sampler.set_epoch(epoch)
|
||||||
@@ -572,9 +650,9 @@ class Trainer:
|
|||||||
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
||||||
self.args.device
|
self.args.device
|
||||||
)
|
)
|
||||||
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
|
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
|
||||||
else:
|
else:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)
|
||||||
|
|
||||||
# Reset the past mems state at the beginning of each epoch if necessary.
|
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
@@ -631,7 +709,8 @@ class Trainer:
|
|||||||
self.log(logs)
|
self.log(logs)
|
||||||
|
|
||||||
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
|
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
|
||||||
self.evaluate()
|
metrics = self.evaluate()
|
||||||
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
|
|
||||||
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
||||||
# In all cases (even distributed/parallel), self.model is always a reference
|
# In all cases (even distributed/parallel), self.model is always a reference
|
||||||
@@ -643,7 +722,15 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
assert model is self.model, f"Model {model} should be a reference to self.model"
|
assert model is self.model, f"Model {model} should be a reference to self.model"
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
||||||
|
if self.hp_search_backend is not None and trial is not None:
|
||||||
|
run_id = (
|
||||||
|
trial.number
|
||||||
|
if self.hp_search_backend == HPSearchBackend.OPTUNA
|
||||||
|
else tune.get_trial_id()
|
||||||
|
)
|
||||||
|
checkpoint_folder += f"-run-{run_id}"
|
||||||
|
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
||||||
|
|
||||||
self.save_model(output_dir)
|
self.save_model(output_dir)
|
||||||
|
|
||||||
@@ -683,6 +770,108 @@ class Trainer:
|
|||||||
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
||||||
return TrainOutput(self.global_step, tr_loss / self.global_step)
|
return TrainOutput(self.global_step, tr_loss / self.global_step)
|
||||||
|
|
||||||
|
def hyperparameter_search(
|
||||||
|
self,
|
||||||
|
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
||||||
|
compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
|
||||||
|
n_trials: int = 20,
|
||||||
|
timeout: int = 1800,
|
||||||
|
n_jobs: int = 1,
|
||||||
|
direction: str = "minimize",
|
||||||
|
backend: Optional[Union["str", HPSearchBackend]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> BestRun:
|
||||||
|
"""
|
||||||
|
Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by the
|
||||||
|
method, which is the evaluation loss when no metric is provided, the sum of all metrics otherwise (you can
|
||||||
|
change that behavior by subclassing and overriding this method).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
|
||||||
|
A function that defines the hyperparameter search space. Will default to
|
||||||
|
:func:`~transformers.trainer_utils.default_hp_space_optuna` or
|
||||||
|
:func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
|
||||||
|
compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
|
||||||
|
A function computing the objective to minimize or maximize from the metrics returned by the
|
||||||
|
:obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
|
||||||
|
n_trials (:obj:`int`, `optional`, defaults to 100):
|
||||||
|
The number of trial runs to test.
|
||||||
|
direction(:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
|
||||||
|
Whether to optimize greater or lower objects. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
|
||||||
|
pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
|
||||||
|
several metrics.
|
||||||
|
backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
|
||||||
|
The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
|
||||||
|
one is installed. If both are installed, will default to optuna.
|
||||||
|
kwargs:
|
||||||
|
Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
|
||||||
|
more information see:
|
||||||
|
|
||||||
|
- the documentation of `optuna.create_stufy <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
|
||||||
|
- the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:class:`transformers.trainer_utils.BestRun`: All the informations about the best run.
|
||||||
|
"""
|
||||||
|
if backend is None:
|
||||||
|
backend = default_hp_search_backend()
|
||||||
|
if backend is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"At least one of optuna or ray should be installed. "
|
||||||
|
"To install optuna run `pip install optuna`."
|
||||||
|
"To install ray run `pip install ray[tune]`."
|
||||||
|
)
|
||||||
|
backend = HPSearchBackend(backend)
|
||||||
|
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
|
||||||
|
raise RuntimeError(" You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
|
||||||
|
if backend == HPSearchBackend.RAY and not is_ray_available():
|
||||||
|
raise RuntimeError(
|
||||||
|
" You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
||||||
|
)
|
||||||
|
self.hp_search_backend = backend
|
||||||
|
|
||||||
|
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
||||||
|
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
||||||
|
|
||||||
|
def _objective(trial):
|
||||||
|
# To make sure optimizer and lr_scheduler are reset with the new choices of HPs
|
||||||
|
self.optimizer = None
|
||||||
|
self.lr_scheduler = None
|
||||||
|
self.objective = None
|
||||||
|
self.train(trial=trial)
|
||||||
|
# If there hasn't been any evaluation during the training loop.
|
||||||
|
if getattr(self, "objective", None) is None:
|
||||||
|
metrics = self.evaluate()
|
||||||
|
self.objective = self.compute_objective(metrics)
|
||||||
|
return self.objective
|
||||||
|
|
||||||
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||||
|
timeout = kwargs.pop("timeout", None)
|
||||||
|
n_jobs = kwargs.pop("n_jobs", 1)
|
||||||
|
study = optuna.create_study(direction=direction, **kwargs)
|
||||||
|
study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
|
||||||
|
best_trial = study.best_trial
|
||||||
|
best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
|
||||||
|
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||||
|
# The TensorBoard writer does not pickle so we have to remove it (if it exists) while doing the ray hp
|
||||||
|
# search.
|
||||||
|
_tb_writer = self.tb_writer
|
||||||
|
self.tb_writer = None
|
||||||
|
# Setup default `resources_per_trial` and `reporter`.
|
||||||
|
if "resources_per_trial" not in kwargs and self.args.n_gpu > 0:
|
||||||
|
kwargs["resources_per_trial"] = {"gpu": self.args.n_gpu}
|
||||||
|
if "reporter" not in kwargs:
|
||||||
|
from ray.tune import CLIReporter
|
||||||
|
|
||||||
|
kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
|
||||||
|
analysis = tune.run(_objective, config=self.hp_space(None), num_samples=n_trials, **kwargs)
|
||||||
|
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
|
||||||
|
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
|
||||||
|
self.tb_writer = _tb_writer
|
||||||
|
|
||||||
|
self.hp_search_backend = None
|
||||||
|
return best_run
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
|
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log :obj:`logs` on the various objects watching training.
|
Log :obj:`logs` on the various objects watching training.
|
||||||
@@ -1020,8 +1209,9 @@ class Trainer:
|
|||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = None
|
self._past = None
|
||||||
|
|
||||||
|
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
||||||
samples_count = 0
|
samples_count = 0
|
||||||
for inputs in tqdm(dataloader, desc=description):
|
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
||||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||||
samples_count += batch_size
|
samples_count += batch_size
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import random
|
import random
|
||||||
from typing import Dict, NamedTuple, Optional
|
from typing import Any, Dict, NamedTuple, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .file_utils import is_tf_available, is_torch_available
|
from .file_utils import is_tf_available, is_torch_available
|
||||||
|
from .integrations import is_ray_available
|
||||||
|
from .tokenization_utils_base import ExplicitEnum
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int):
|
def set_seed(seed: int):
|
||||||
@@ -53,3 +55,70 @@ class TrainOutput(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
||||||
|
|
||||||
|
|
||||||
|
class BestRun(NamedTuple):
|
||||||
|
"""
|
||||||
|
The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
run_id (:obj:`str`):
|
||||||
|
The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
|
||||||
|
with run-{run_id}).
|
||||||
|
objective (:obj:`float`):
|
||||||
|
The objective that was obtained for this run.
|
||||||
|
hyperparameters (:obj:`Dict[str, Any]`):
|
||||||
|
The hyperparameters picked to get this run.
|
||||||
|
"""
|
||||||
|
|
||||||
|
run_id: str
|
||||||
|
objective: float
|
||||||
|
hyperparameters: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def default_compute_objective(metrics: Dict[str, float]) -> float:
|
||||||
|
"""
|
||||||
|
The default objective to maximize/minimize when doing an hyperparameter search. It is the evaluation loss if no
|
||||||
|
metrics are provided to the :class:`~transformers.Trainer`, the sum of all metrics otherwise.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics (:obj:`Dict[str, float]`): The metrics returned by the evaluate method.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`float`: The objective to minimize or maximize
|
||||||
|
"""
|
||||||
|
loss = metrics.pop("eval_loss", None)
|
||||||
|
_ = metrics.pop("epoch", None)
|
||||||
|
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||||
|
|
||||||
|
|
||||||
|
def default_hp_space_optuna(trial) -> Dict[str, float]:
|
||||||
|
return {
|
||||||
|
"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
|
||||||
|
"num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
|
||||||
|
"seed": trial.suggest_int("seed", 1, 40),
|
||||||
|
"per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def default_hp_space_ray(trial) -> Dict[str, float]:
|
||||||
|
assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
|
||||||
|
from ray import tune
|
||||||
|
|
||||||
|
return {
|
||||||
|
"learning_rate": tune.loguniform(1e-6, 1e-4),
|
||||||
|
"num_train_epochs": tune.choice(range(1, 6)),
|
||||||
|
"seed": tune.uniform(1, 40),
|
||||||
|
"per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class HPSearchBackend(ExplicitEnum):
|
||||||
|
OPTUNA = "optuna"
|
||||||
|
RAY = "ray"
|
||||||
|
|
||||||
|
|
||||||
|
default_hp_space = {
|
||||||
|
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
||||||
|
HPSearchBackend.RAY: default_hp_space_ray,
|
||||||
|
}
|
||||||
|
|||||||
@@ -114,6 +114,9 @@ class TrainingArguments:
|
|||||||
at the next training step under the keyword argument ``mems``.
|
at the next training step under the keyword argument ``mems``.
|
||||||
run_name (:obj:`str`, `optional`):
|
run_name (:obj:`str`, `optional`):
|
||||||
A descriptor for the run. Notably used for wandb logging.
|
A descriptor for the run. Notably used for wandb logging.
|
||||||
|
disable_tqdm (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to disable the tqdm progress bars. Will default to :obj:`True` if the logging level is set
|
||||||
|
to warn or lower (default), :obj:`False` otherwise.
|
||||||
remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
|
If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
|
||||||
forward method.
|
forward method.
|
||||||
@@ -238,6 +241,13 @@ class TrainingArguments:
|
|||||||
run_name: Optional[str] = field(
|
run_name: Optional[str] = field(
|
||||||
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
|
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
|
||||||
)
|
)
|
||||||
|
disable_tqdm: Optional[bool] = field(
|
||||||
|
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.disable_tqdm is None:
|
||||||
|
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||||
|
|
||||||
remove_unused_columns: Optional[bool] = field(
|
remove_unused_columns: Optional[bool] = field(
|
||||||
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
|
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
|
||||||
|
|||||||
Reference in New Issue
Block a user