* Should check if `torch` is available * fixed samples_count error, distributed_concat arguments * style * Import torch at beginning of file Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
170 lines
5.4 KiB
Python
170 lines
5.4 KiB
Python
import random
|
|
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
|
|
|
import numpy as np
|
|
|
|
from .file_utils import is_tf_available, is_torch_available
|
|
from .tokenization_utils_base import ExplicitEnum
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
|
|
def set_seed(seed: int):
    """
    Seed every random number generator this module knows about, for reproducible runs.

    ``random`` and ``numpy`` are always seeded; ``torch`` and ``tf`` are seeded only when they are installed.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    # Generators that are always present.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    if is_torch_available():
        import torch

        torch.manual_seed(seed)
        # Seeding all CUDA devices is safe even if cuda is not available.
        torch.cuda.manual_seed_all(seed)

    if is_tf_available():
        import tensorflow as tf

        tf.random.set_seed(seed)
|
|
|
|
|
|
class EvalPrediction(NamedTuple):
    """
    Output of an evaluation pass, meant to be fed to the metric computation. Labels are always present.

    Parameters:
        predictions (:obj:`np.ndarray`): Predictions of the model.
        label_ids (:obj:`np.ndarray`): Targets to be matched.
    """

    predictions: np.ndarray
    label_ids: np.ndarray
|
|
|
|
|
|
class PredictionOutput(NamedTuple):
    """
    Named tuple bundling predictions with their (optional) labels and (optional) metrics.

    Parameters:
        predictions (:obj:`np.ndarray`): The predictions.
        label_ids (:obj:`np.ndarray`, `optional`): The labels, or :obj:`None` when there are none.
        metrics (:obj:`Dict[str, float]`, `optional`): Metric name to value, or :obj:`None` when there are none.
    """

    predictions: np.ndarray
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]
|
|
|
|
|
|
class TrainOutput(NamedTuple):
    """
    Named tuple pairing a global step count with a training loss.

    Parameters:
        global_step (:obj:`int`): The global step value.
        training_loss (:obj:`float`): The training loss value.
    """

    global_step: int
    training_loss: float
|
|
|
|
|
|
# Name prefix for checkpoint directories.
# NOTE(review): presumably consumed by the trainer when naming saved checkpoints — confirm against callers.
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
|
|
|
|
|
class BestRun(NamedTuple):
    """
    The best run found by a hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).

    Parameters:
        run_id (:obj:`str`):
            The id of the best run (if models were saved, the corresponding checkpoint will be in the folder
            ending with run-{run_id}).
        objective (:obj:`float`):
            The objective that was obtained for this run.
        hyperparameters (:obj:`Dict[str, Any]`):
            The hyperparameters picked to get this run.
    """

    run_id: str
    objective: float
    hyperparameters: Dict[str, Any]
|
|
|
|
|
|
def default_compute_objective(metrics: Dict[str, float]) -> float:
    """
    The default objective to maximize/minimize when doing a hyperparameter search. It is the evaluation loss if no
    metrics are provided to the :class:`~transformers.Trainer`, the sum of all metrics otherwise.

    The ``"eval_loss"`` and ``"epoch"`` entries are never counted as metrics. The ``metrics`` dictionary passed in
    is left untouched, so the caller can keep using it afterwards.

    Args:
        metrics (:obj:`Dict[str, float]`): The metrics returned by the evaluate method.

    Return:
        :obj:`float`: The objective to minimize or maximize. This is :obj:`None` when there is neither an
        ``"eval_loss"`` entry nor any other metric.
    """
    # Work on a filtered view instead of popping keys, so the caller's dict is not mutated.
    bookkeeping_keys = ("eval_loss", "epoch")
    real_metrics = [value for name, value in metrics.items() if name not in bookkeeping_keys]
    if real_metrics:
        return sum(real_metrics)
    return metrics.get("eval_loss")
|
|
|
|
|
|
def default_hp_space_optuna(trial) -> Dict[str, float]:
    """
    Default hyperparameter search space for the Optuna backend.

    Args:
        trial: Object exposing ``suggest_float``, ``suggest_int`` and ``suggest_categorical``; each
            hyperparameter is drawn from it.

    Return:
        A dictionary with the suggested ``learning_rate``, ``num_train_epochs``, ``seed`` and
        ``per_device_train_batch_size``.

    Raises:
        AssertionError: If Optuna is not installed.
    """
    from .integrations import is_optuna_available

    assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"

    # Draw the values in the same order as the keys of the returned dictionary.
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    num_train_epochs = trial.suggest_int("num_train_epochs", 1, 5)
    seed = trial.suggest_int("seed", 1, 40)
    batch_size = trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64])

    return {
        "learning_rate": learning_rate,
        "num_train_epochs": num_train_epochs,
        "seed": seed,
        "per_device_train_batch_size": batch_size,
    }
|
|
|
|
|
|
def default_hp_space_ray(trial) -> Dict[str, float]:
    """
    Default hyperparameter search space for the Ray Tune backend.

    Args:
        trial: Not used by this function; kept so the signature matches the other backend.

    Return:
        A dictionary mapping ``learning_rate``, ``num_train_epochs``, ``seed`` and
        ``per_device_train_batch_size`` to the ``ray.tune`` objects they are sampled from.

    Raises:
        AssertionError: If ray is not installed.
    """
    from .integrations import is_ray_available

    assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
    from ray import tune

    search_space = {}
    search_space["learning_rate"] = tune.loguniform(1e-6, 1e-4)
    search_space["num_train_epochs"] = tune.choice([1, 2, 3, 4, 5])
    search_space["seed"] = tune.uniform(1, 40)
    search_space["per_device_train_batch_size"] = tune.choice([4, 8, 16, 32, 64])
    return search_space
|
|
|
|
|
|
class HPSearchBackend(ExplicitEnum):
    """Identifiers of the hyperparameter search backends this module provides a default search space for."""

    OPTUNA = "optuna"
    RAY = "ray"
|
|
|
|
|
|
# Maps each search backend to the function that builds its default hyperparameter search space.
default_hp_space = {
    HPSearchBackend.OPTUNA: default_hp_space_optuna,
    HPSearchBackend.RAY: default_hp_space_ray,
}
|
|
|
|
|
|
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
    """
    Gather ``tensor`` from every process with ``torch.distributed.all_gather`` and concatenate the results
    along dimension 0.

    Args:
        tensor (:obj:`torch.Tensor`): The local tensor to share with the other processes.
        num_total_examples (:obj:`int`, `optional`):
            If given, the concatenated result is cut to its first ``num_total_examples`` entries.

    Return:
        :obj:`torch.Tensor`: The concatenation of the tensors gathered from all processes.

    Raises:
        ImportError: If torch is not installed.
        AssertionError: If one of the torch calls raises an :obj:`AssertionError`; it is re-raised with the
            message "Not currently using distributed training".
    """
    if not is_torch_available():
        raise ImportError("Torch must be installed to use `distributed_concat`")

    try:
        world_size = torch.distributed.get_world_size()
        # One receive buffer per process, each shaped like the local tensor.
        gathered = [tensor.clone() for _ in range(world_size)]
        torch.distributed.all_gather(gathered, tensor)
        merged = torch.cat(gathered, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        if num_total_examples is None:
            return merged
        return merged[:num_total_examples]
    except AssertionError:
        raise AssertionError("Not currently using distributed training")
|
|
|
|
|
|
def distributed_broadcast_scalars(
    scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> "torch.Tensor":
    """
    Turn ``scalars`` into a CUDA tensor, gather it from every process with ``torch.distributed.all_gather`` and
    concatenate the results along dimension 0.

    Args:
        scalars (:obj:`List[Union[int, float]]`): The local values to share with the other processes.
        num_total_examples (:obj:`int`, `optional`):
            If given, the concatenated result is cut to its first ``num_total_examples`` entries.

    Return:
        :obj:`torch.Tensor`: The concatenation of the tensors gathered from all processes.

    Raises:
        ImportError: If torch is not installed.
        AssertionError: If one of the torch calls raises an :obj:`AssertionError`; it is re-raised with the
            message "Not currently using distributed training".
    """
    if not is_torch_available():
        raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")

    try:
        local_values = torch.Tensor(scalars).cuda()
        world_size = torch.distributed.get_world_size()
        # One receive buffer per process, each shaped like the local tensor.
        gathered = [local_values.clone() for _ in range(world_size)]
        torch.distributed.all_gather(gathered, local_values)
        merged = torch.cat(gathered, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        if num_total_examples is None:
            return merged
        return merged[:num_total_examples]
    except AssertionError:
        raise AssertionError("Not currently using distributed training")
|