1453 lines
66 KiB
Python
Executable File
1453 lines
66 KiB
Python
Executable File
import inspect
|
|
import json
|
|
import math
|
|
import os
|
|
import re
|
|
import shutil
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from packaging import version
|
|
from torch import nn
|
|
from torch.utils.data.dataloader import DataLoader
|
|
from torch.utils.data.dataset import Dataset
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
|
from tqdm.auto import tqdm, trange
|
|
|
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
|
from .file_utils import is_datasets_available, is_torch_tpu_available
|
|
from .integrations import (
|
|
default_hp_search_backend,
|
|
is_comet_available,
|
|
is_optuna_available,
|
|
is_ray_available,
|
|
is_tensorboard_available,
|
|
is_wandb_available,
|
|
run_hp_search_optuna,
|
|
run_hp_search_ray,
|
|
)
|
|
from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
|
from .modeling_utils import PreTrainedModel
|
|
from .optimization import AdamW, get_linear_schedule_with_warmup
|
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
|
from .trainer_utils import (
|
|
PREFIX_CHECKPOINT_DIR,
|
|
BestRun,
|
|
EvalPrediction,
|
|
HPSearchBackend,
|
|
PredictionOutput,
|
|
TrainOutput,
|
|
default_compute_objective,
|
|
default_hp_space,
|
|
distributed_broadcast_scalars,
|
|
distributed_concat,
|
|
nested_concat,
|
|
nested_numpify,
|
|
nested_xla_mesh_reduce,
|
|
set_seed,
|
|
)
|
|
from .training_args import TrainingArguments
|
|
from .utils import logging
|
|
|
|
|
|
_use_native_amp = False
|
|
_use_apex = False
|
|
|
|
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
|
if version.parse(torch.__version__) < version.parse("1.6"):
|
|
from transformers.file_utils import is_apex_available
|
|
|
|
if is_apex_available():
|
|
from apex import amp
|
|
_use_apex = True
|
|
else:
|
|
_use_native_amp = True
|
|
from torch.cuda.amp import autocast
|
|
|
|
if is_datasets_available():
|
|
import datasets
|
|
|
|
if is_torch_tpu_available():
|
|
import torch_xla.core.xla_model as xm
|
|
import torch_xla.debug.metrics as met
|
|
import torch_xla.distributed.parallel_loader as pl
|
|
|
|
if is_tensorboard_available():
|
|
try:
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
except ImportError:
|
|
from tensorboardX import SummaryWriter
|
|
|
|
if is_wandb_available():
|
|
import wandb
|
|
|
|
if is_comet_available():
|
|
import comet_ml
|
|
|
|
if is_optuna_available():
|
|
import optuna
|
|
|
|
if is_ray_available():
|
|
from ray import tune
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that lets the local master run the wrapped block before the other local processes.

    A process whose ``local_rank`` is neither ``-1`` nor ``0`` waits on a barrier before the wrapped block runs.
    The process with ``local_rank == 0`` calls the barrier once the wrapped block has finished. With
    ``local_rank == -1`` no barrier is called at all.

    Args:
        local_rank (:obj:`int`): The rank of the local process.
    """
    is_local_master = local_rank == 0
    must_wait_for_master = not is_local_master and local_rank != -1

    if must_wait_for_master:
        torch.distributed.barrier()

    yield

    if is_local_master:
        torch.distributed.barrier()
|
|
|
|
|
|
class SequentialDistributedSampler(Sampler):
    """
    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        num_replicas (:obj:`int`, `optional`): Number of processes sharing the dataset. Defaults to
            ``torch.distributed.get_world_size()``.
        rank (:obj:`int`, `optional`): Index of the current process. Defaults to
            ``torch.distributed.get_rank()``.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` is omitted and the distributed package is unavailable.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # Every replica gets the same number of samples (rounded up).
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. The padding can be longer than the dataset itself
        # (e.g. 1 sample shared by 3 replicas), so wrap around as many times as needed instead of slicing once.
        padding = self.total_size - len(indices)
        if padding > 0:
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]
        assert (
            len(indices) == self.total_size
        ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"

        # Subsample: each rank takes one contiguous slice.
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        assert (
            len(indices) == self.num_samples
        ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

        return iter(indices)

    def __len__(self):
        return self.num_samples
|
|
|
|
|
|
def get_tpu_sampler(dataset: Dataset):
    """
    Pick the training sampler to use on TPU.

    Returns a :obj:`RandomSampler` when the XLA world size is at most 1, otherwise a :obj:`DistributedSampler`
    configured with the XLA world size and the ordinal of the current process.
    """
    world_size = xm.xrt_world_size()
    if world_size > 1:
        return DistributedSampler(dataset, num_replicas=world_size, rank=xm.get_ordinal())
    return RandomSampler(dataset)
|
|
|
|
|
|
class Trainer:
|
|
"""
|
|
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
|
optimized for 🤗 Transformers.
|
|
|
|
Args:
|
|
model (:class:`~transformers.PreTrainedModel`, `optional`):
|
|
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
|
|
args (:class:`~transformers.TrainingArguments`, `optional`):
|
|
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
|
|
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
|
|
data_collator (:obj:`DataCollator`, `optional`):
|
|
The function to use to form a batch from a list of elements of :obj:`train_dataset` or
|
|
:obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
|
|
provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
|
|
train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
|
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
|
``model.forward()`` method are automatically removed.
|
|
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
|
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
|
``model.forward()`` method are automatically removed.
|
|
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
|
|
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
|
|
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
|
|
interrupted training or reuse the fine-tuned model.
|
|
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
|
|
A function that instantiates the model to be used. If provided, each call to
|
|
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
|
|
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
|
The function that will be used to compute metrics at evaluation. Must take a
|
|
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
|
tb_writer (:obj:`SummaryWriter`, `optional`):
|
|
Object to write to TensorBoard.
|
|
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
|
|
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
|
|
:class:`~transformers.AdamW` on your model and a scheduler given by
|
|
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
|
|
kwargs:
|
|
Deprecated keyword arguments.
|
|
"""
|
|
|
|
def __init__(
    self,
    model: PreTrainedModel = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Dataset] = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    model_init: Callable[[], PreTrainedModel] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
    tb_writer: Optional["SummaryWriter"] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    **kwargs,
):
    """
    Store the training components on the instance and prepare logging, the output directory and defaults.

    Raises:
        AssertionError: If neither ``model`` nor ``model_init`` is given, or if unexpected keyword arguments remain.
        RuntimeError: If ``model_init`` is combined with a non-empty ``optimizers`` tuple.
    """
    if args is None:
        logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
        args = TrainingArguments("tmp_trainer")
    self.args = args
    # Seed must be set before instantiating the model when using model
    set_seed(self.args.seed)
    assert (
        model is not None or model_init is not None
    ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
    if model is None and model_init is not None:
        model = model_init()
    self.model = model.to(args.device) if model is not None else None
    default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
    self.data_collator = data_collator if data_collator is not None else default_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.tokenizer = tokenizer
    self.model_init = model_init
    self.compute_metrics = compute_metrics
    self.optimizer, self.lr_scheduler = optimizers
    if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
        # The two literals are concatenated, so the first one needs a trailing space.
        raise RuntimeError(
            "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )
    self.tb_writer = tb_writer
    self.log_history = []
    if "prediction_loss_only" in kwargs:
        warnings.warn(
            "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
            FutureWarning,
        )
        self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
    assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

    if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
        self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
    if not is_tensorboard_available():
        logger.warning(
            "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
        )

    # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
    self._loggers_initialized = False

    # Create output directory if needed
    if self.is_world_process_zero():
        os.makedirs(self.args.output_dir, exist_ok=True)
    if is_torch_tpu_available():
        # Set an xla_device flag on the model's config.
        # We'll find a more elegant way and not need to do this in the future.
        self.model.config.xla_device = True
    if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
        # Legacy collator objects expose `collate_batch` instead of being callable themselves.
        self.data_collator = self.data_collator.collate_batch
        warnings.warn(
            (
                "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                + "with a `collate_batch` are deprecated and won't be supported in a future version."
            ),
            FutureWarning,
        )

    if is_datasets_available():
        if isinstance(train_dataset, datasets.Dataset):
            self._remove_unused_columns(self.train_dataset, description="training")
        if isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(self.eval_dataset, description="evaluation")

    self.global_step = None
    self.epoch = None
    self.total_flos = None
    if self.args.fp16 and _use_native_amp:
        self.scaler = torch.cuda.amp.GradScaler()
    self.hp_search_backend = None
    self.use_tune_checkpoints = False
    if self.args.label_names is None:
        # Each label name must be its own list entry; a single "start_positions, end_positions" string
        # would be treated as one (non-existent) label name.
        self.args.label_names = (
            ["start_positions", "end_positions"]
            if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
            else ["labels"]
        )
|
|
|
|
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
|
if not self.args.remove_unused_columns:
|
|
return
|
|
# Inspect model forward signature to keep only the arguments it accepts.
|
|
signature = inspect.signature(self.model.forward)
|
|
signature_columns = list(signature.parameters.keys())
|
|
# Labels may be named label or label_ids, the default data collator handles that.
|
|
signature_columns += ["label", "label_ids"]
|
|
columns = [k for k in signature_columns if k in dataset.column_names]
|
|
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
|
|
dset_description = "" if description is None else f"in the {description} set "
|
|
logger.info(
|
|
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
|
)
|
|
dataset.set_format(type=dataset.format["type"], columns=columns)
|
|
|
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
|
return None
|
|
elif is_torch_tpu_available():
|
|
return get_tpu_sampler(self.train_dataset)
|
|
else:
|
|
return (
|
|
RandomSampler(self.train_dataset)
|
|
if self.args.local_rank == -1
|
|
else DistributedSampler(self.train_dataset)
|
|
)
|
|
|
|
def get_train_dataloader(self) -> DataLoader:
|
|
"""
|
|
Returns the training :class:`~torch.utils.data.DataLoader`.
|
|
|
|
Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
|
|
(adapted to distributed training if necessary) otherwise.
|
|
|
|
Subclass and override this method if you want to inject some custom behavior.
|
|
"""
|
|
if self.train_dataset is None:
|
|
raise ValueError("Trainer: training requires a train_dataset.")
|
|
train_sampler = self._get_train_sampler()
|
|
|
|
return DataLoader(
|
|
self.train_dataset,
|
|
batch_size=self.args.train_batch_size,
|
|
sampler=train_sampler,
|
|
collate_fn=self.data_collator,
|
|
drop_last=self.args.dataloader_drop_last,
|
|
)
|
|
|
|
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
|
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
|
return None
|
|
elif is_torch_tpu_available():
|
|
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
|
elif self.args.local_rank != -1:
|
|
return SequentialDistributedSampler(eval_dataset)
|
|
else:
|
|
return SequentialSampler(eval_dataset)
|
|
|
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
|
"""
|
|
Returns the evaluation :class:`~torch.utils.data.DataLoader`.
|
|
|
|
Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
|
|
sampler (adapted to distributed training if necessary) otherwise.
|
|
|
|
Subclass and override this method if you want to inject some custom behavior.
|
|
|
|
Args:
|
|
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
|
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
|
|
accepted by the ``model.forward()`` method are automatically removed.
|
|
"""
|
|
if eval_dataset is None and self.eval_dataset is None:
|
|
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
|
elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
|
self._remove_unused_columns(eval_dataset, description="evaluation")
|
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
|
|
|
return DataLoader(
|
|
eval_dataset,
|
|
sampler=eval_sampler,
|
|
batch_size=self.args.eval_batch_size,
|
|
collate_fn=self.data_collator,
|
|
drop_last=self.args.dataloader_drop_last,
|
|
)
|
|
|
|
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
    """
    Build the test :class:`~torch.utils.data.DataLoader`.

    No sampler is used when :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`; otherwise a
    sequential sampler (adapted to distributed training if necessary) is used.

    Subclass and override this method if you want to inject some custom behavior.

    Args:
        test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
            The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.
    """
    prune_columns = is_datasets_available() and isinstance(test_dataset, datasets.Dataset)
    if prune_columns:
        self._remove_unused_columns(test_dataset, description="test")

    # The evaluation batch size and sampler selection are reused for the test set.
    loader_options = {
        "sampler": self._get_eval_sampler(test_dataset),
        "batch_size": self.args.eval_batch_size,
        "collate_fn": self.data_collator,
        "drop_last": self.args.dataloader_drop_last,
    }
    return DataLoader(test_dataset, **loader_options)
|
|
|
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
|
"""
|
|
Setup the optimizer and the learning rate scheduler.
|
|
|
|
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
|
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
|
|
"""
|
|
if self.optimizer is None:
|
|
no_decay = ["bias", "LayerNorm.weight"]
|
|
optimizer_grouped_parameters = [
|
|
{
|
|
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
"weight_decay": self.args.weight_decay,
|
|
},
|
|
{
|
|
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
"weight_decay": 0.0,
|
|
},
|
|
]
|
|
self.optimizer = AdamW(
|
|
optimizer_grouped_parameters,
|
|
lr=self.args.learning_rate,
|
|
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
|
eps=self.args.adam_epsilon,
|
|
)
|
|
if self.lr_scheduler is None:
|
|
self.lr_scheduler = get_linear_schedule_with_warmup(
|
|
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
|
)
|
|
|
|
def setup_wandb(self):
|
|
"""
|
|
Setup the optional Weights & Biases (`wandb`) integration.
|
|
|
|
One can subclass and override this method to customize the setup if needed. Find more information
|
|
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
|
|
|
|
Environment:
|
|
WANDB_WATCH:
|
|
(Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
|
|
or "all" to log gradients and parameters
|
|
WANDB_PROJECT:
|
|
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
|
|
WANDB_DISABLED:
|
|
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
|
|
"""
|
|
if hasattr(self, "_setup_wandb"):
|
|
warnings.warn(
|
|
"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
|
|
FutureWarning,
|
|
)
|
|
return self._setup_wandb()
|
|
|
|
if self.is_world_process_zero():
|
|
logger.info(
|
|
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
|
)
|
|
try:
|
|
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
|
except AttributeError:
|
|
# in case the model has no config
|
|
combined_dict = {**self.args.to_sanitized_dict()}
|
|
wandb.init(
|
|
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
|
)
|
|
# keep track of model topology and gradients, unsupported on TPU
|
|
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
|
wandb.watch(
|
|
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
|
|
)
|
|
|
|
def setup_comet(self):
    """
    Setup the optional Comet.ml integration.

    Only the world process zero creates an experiment. Nothing is created when ``COMET_MODE`` is neither
    "ONLINE" nor "OFFLINE" (case-insensitive).

    Environment:
        COMET_MODE:
            (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
        COMET_PROJECT_NAME:
            (Optional): str - Comet.ml project name for experiments
        COMET_OFFLINE_DIRECTORY:
            (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

    For a number of configurable items in the environment,
    see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
    """
    # Use `is_world_process_zero` like the rest of the class instead of the older `is_world_master` name.
    if self.is_world_process_zero():
        comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
        args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
        experiment = None
        if comet_mode == "ONLINE":
            experiment = comet_ml.Experiment(**args)
            logger.info("Automatic Comet.ml online logging enabled")
        elif comet_mode == "OFFLINE":
            args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
            experiment = comet_ml.OfflineExperiment(**args)
            logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
        if experiment is not None:
            experiment._set_model_graph(self.model, framework="transformers")
            experiment._log_parameters(self.args, prefix="args/", framework="transformers")
            experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
|
|
|
|
def num_examples(self, dataloader: DataLoader) -> int:
    """
    Return the number of samples behind a :class:`~torch.utils.data.DataLoader`, i.e. the length of its dataset.
    """
    underlying_dataset = dataloader.dataset
    return len(underlying_dataset)
|
|
|
|
def _setup_loggers(self):
|
|
if self._loggers_initialized:
|
|
return
|
|
if is_wandb_available():
|
|
self.setup_wandb()
|
|
elif os.environ.get("WANDB_DISABLED") != "true":
|
|
logger.info(
|
|
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
|
|
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
|
|
)
|
|
if is_comet_available():
|
|
self.setup_comet()
|
|
elif os.environ.get("COMET_MODE") != "DISABLED":
|
|
logger.info(
|
|
"To use comet_ml logging, run `pip/conda install comet_ml` "
|
|
"see https://www.comet.ml/docs/python-sdk/huggingface/"
|
|
)
|
|
self._loggers_initialized = True
|
|
|
|
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
|
""" HP search setup code """
|
|
if self.hp_search_backend is None or trial is None:
|
|
return
|
|
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
|
for key, value in params.items():
|
|
if not hasattr(self.args, key):
|
|
raise AttributeError(
|
|
f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
|
|
)
|
|
old_attr = getattr(self.args, key, None)
|
|
# Casting value to the proper type
|
|
if old_attr is not None:
|
|
value = type(old_attr)(value)
|
|
setattr(self.args, key, value)
|
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
|
logger.info("Trial:", trial.params)
|
|
|
|
def _report_to_hp_search(
|
|
self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
|
|
):
|
|
if self.hp_search_backend is None or trial is None:
|
|
return
|
|
self.objective = self.compute_objective(metrics)
|
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
|
trial.report(self.objective, epoch)
|
|
if trial.should_prune():
|
|
raise optuna.TrialPruned()
|
|
elif self.hp_search_backend == HPSearchBackend.RAY:
|
|
if self.global_step % self.args.save_steps == 0:
|
|
self._tune_save_checkpoint()
|
|
tune.report(objective=self.objective, **metrics)
|
|
|
|
def _tune_save_checkpoint(self):
|
|
if not self.use_tune_checkpoints:
|
|
return
|
|
with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
|
|
self.args.output_dir = checkpoint_dir
|
|
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
|
|
self.save_model(output_dir)
|
|
if self.is_world_master():
|
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
|
|
|
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
    """
    Main training entry point.

    Args:
        model_path (:obj:`str`, `optional`):
            Local path to the model if the model to train has been instantiated from a local path. If present,
            training will resume from the optimizer/scheduler states loaded here.
        trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
            The trial run or the hyperparameter dictionary for hyperparameter search.

    Returns:
        :obj:`TrainOutput`: Built from the final ``global_step`` and the accumulated training loss divided by
        ``global_step`` (by 1 when no optimization step was taken).
    """
    # This might change the seed so needs to run first.
    self._hp_search_setup(trial)

    # Model re-init
    if self.model_init is not None:
        # Seed must be set before instantiating the model when using model_init.
        set_seed(self.args.seed)
        model = self.model_init()
        self.model = model.to(self.args.device)

        # Reinitializes optimizer and scheduler
        self.optimizer, self.lr_scheduler = None, None

    # Data loader and number of training steps
    train_dataloader = self.get_train_dataloader()
    num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
    num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        # Ceiling division: a partial epoch still counts as an epoch to run.
        num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
            self.args.max_steps % num_update_steps_per_epoch > 0
        )
    else:
        t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs
        self.args.max_steps = t_total

    self.create_optimizer_and_scheduler(num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (
        model_path is not None
        and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        self.optimizer.load_state_dict(
            torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
        )
        self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

    model = self.model
    if self.args.fp16 and _use_apex:
        if not is_apex_available():
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if self.args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if self.args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=True,
        )

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())
        self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

    # Train!
    if is_torch_tpu_available():
        total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
    else:
        total_train_batch_size = (
            self.args.train_batch_size
            * self.args.gradient_accumulation_steps
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
        )
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", self.num_examples(train_dataloader))
    logger.info(" Num Epochs = %d", num_train_epochs)
    logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
    logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    self.global_step = 0
    self.epoch = 0
    self.total_flos = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if model_path is not None:
        # set global_step to global_step of last saved checkpoint from model path
        try:
            self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
            self.total_flos = getattr(model.config, "total_flos", 0)

            epochs_trained = self.global_step // num_update_steps_per_epoch
            steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", self.global_step)
            logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
            logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            # The path does not end in a step number: treat it as a fresh fine-tuning run.
            self.global_step = 0
            self.total_flos = 0
            logger.info(" Starting fine-tuning.")

    tr_loss = torch.tensor(0.0).to(self.args.device)
    logging_loss_scalar = 0.0
    # Global step at which the loss was last logged. The logged loss is averaged over the update steps
    # actually taken since then, which stays correct after `logging_first_step` or a resumed run and avoids
    # dividing by `logging_steps` when it is 0.
    last_logged_step = self.global_step
    model.zero_grad()
    disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
    train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
    for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
        if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                self.args.device
            )
            epoch_iterator = parallel_loader
        else:
            epoch_iterator = train_dataloader

        # Reset the past mems state at the beginning of each epoch if necessary.
        if self.args.past_index >= 0:
            self._past = None

        # The bar is advanced manually below; the loop iterates `epoch_iterator` directly.
        epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
        for step, inputs in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                epoch_pbar.update(1)
                continue

            tr_loss += self.training_step(model, inputs)
            self.total_flos += self.floating_point_ops(inputs)

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(epoch_iterator) <= self.args.gradient_accumulation_steps
                and (step + 1) == len(epoch_iterator)
            ):
                if self.args.fp16 and _use_native_amp:
                    # Gradients must be unscaled before clipping with native AMP.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                elif self.args.fp16 and _use_apex:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                if is_torch_tpu_available():
                    xm.optimizer_step(self.optimizer)
                elif self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.lr_scheduler.step()
                model.zero_grad()
                self.global_step += 1
                self.epoch = epoch + (step + 1) / len(epoch_iterator)

                if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                    self.global_step == 1 and self.args.logging_first_step
                ):
                    logs: Dict[str, float] = {}
                    tr_loss_scalar = tr_loss.item()
                    # `global_step` was just incremented, so at least one step separates two logs.
                    steps_since_last_log = self.global_step - last_logged_step
                    logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / steps_since_last_log
                    # backward compatibility for pytorch schedulers
                    logs["learning_rate"] = (
                        self.lr_scheduler.get_last_lr()[0]
                        if version.parse(torch.__version__) >= version.parse("1.4")
                        else self.lr_scheduler.get_lr()[0]
                    )
                    logging_loss_scalar = tr_loss_scalar
                    last_logged_step = self.global_step

                    self.log(logs)

                if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                    metrics = self.evaluate()
                    self._report_to_hp_search(trial, epoch, metrics)

                if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                    # In all cases (even distributed/parallel), self.model is always a reference
                    # to the model we want to save.
                    if hasattr(model, "module"):
                        assert (
                            model.module is self.model
                        ), f"Module {model.module} should be a reference to self.model"
                    else:
                        assert model is self.model, f"Model {model} should be a reference to self.model"
                    # Save model checkpoint
                    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                    if self.hp_search_backend is not None and trial is not None:
                        run_id = (
                            trial.number
                            if self.hp_search_backend == HPSearchBackend.OPTUNA
                            else tune.get_trial_id()
                        )
                        checkpoint_folder += f"-run-{run_id}"
                    output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

                    self.save_model(output_dir)

                    if self.is_world_process_zero():
                        self._rotate_checkpoints(use_mtime=True)

                    if is_torch_tpu_available():
                        xm.rendezvous("saving_optimizer_states")
                        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    elif self.is_world_process_zero():
                        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

            epoch_pbar.update(1)
            if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
                break
        epoch_pbar.close()
        train_pbar.update(1)
        if self.args.tpu_metrics_debug or self.args.debug:
            if is_torch_tpu_available():
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())
            else:
                logger.warning(
                    "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                    "configured. Check your training configuration if this is unexpected."
                )
        if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
            break

    train_pbar.close()
    if self.tb_writer:
        self.tb_writer.close()
    # Same condition as the per-epoch reset above; a plain truthiness test would skip `past_index == 0`.
    if self.args.past_index >= 0 and hasattr(self, "_past"):
        # Clean the state at the end of training
        delattr(self, "_past")

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    # Avoid a ZeroDivisionError when no optimization step was taken.
    return TrainOutput(self.global_step, tr_loss.item() / max(self.global_step, 1))
|
|
|
|
def hyperparameter_search(
    self,
    hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
    compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
    n_trials: int = 20,
    direction: str = "minimize",
    backend: Optional[Union["str", HPSearchBackend]] = None,
    **kwargs
) -> BestRun:
    """
    Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
    :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
    the sum of all metrics otherwise.

    .. warning::

        To use this method, you need to have provided a ``model_init`` when initializing your
        :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
        with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
        method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

    Args:
        hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
            A function that defines the hyperparameter search space. Will default to
            :func:`~transformers.trainer_utils.default_hp_space_optuna` or
            :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
        compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
            A function computing the objective to minimize or maximize from the metrics returned by the
            :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
        n_trials (:obj:`int`, `optional`, defaults to 20):
            The number of trial runs to test.
        direction(:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
            Whether to optimize greater or lower objectives. Can be :obj:`"minimize"` or :obj:`"maximize"`, you
            should pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing
            one or several metrics.
        backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
            The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
            one is installed. If both are installed, will default to optuna.
        kwargs:
            Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
            more information see:

            - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
            - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

    Returns:
        :class:`transformers.trainer_utils.BestRun`: All the information about the best run.

    Raises:
        RuntimeError: If no backend is installed, if the requested backend is not installed, or if no
            ``model_init`` was provided to the trainer.
    """
    if backend is None:
        backend = default_hp_search_backend()
        if backend is None:
            raise RuntimeError(
                "At least one of optuna or ray should be installed. "
                "To install optuna run `pip install optuna`. "
                "To install ray run `pip install ray[tune]`."
            )
    backend = HPSearchBackend(backend)
    if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
        raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
    if backend == HPSearchBackend.RAY and not is_ray_available():
        raise RuntimeError(
            "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
        )

    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    # Only flag the trainer as "in a search" once every precondition passed, and always clear the
    # flag afterwards (even if a trial raises) so later calls do not see a stale backend.
    self.hp_search_backend = backend
    try:
        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
        best_run = run_hp_search(self, n_trials, direction, **kwargs)
    finally:
        self.hp_search_backend = None
    return best_run
|
|
|
|
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
    """
    Log :obj:`logs` on the various objects watching training.

    Subclass and override this method to inject custom behavior.

    The :obj:`logs` dictionary is updated in place with the current epoch and the total number of
    floating-point operations when those are available.

    Args:
        logs (:obj:`Dict[str, float]`):
            The values to log.
        iterator (:obj:`tqdm`, `optional`):
            A potential tqdm progress bar to write the logs on.
    """
    # Set up loggers like W&B or Comet ML
    self._setup_loggers()

    if hasattr(self, "_log"):
        warnings.warn(
            "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
            FutureWarning,
        )
        return self._log(logs, iterator=iterator)

    if self.epoch is not None:
        logs["epoch"] = self.epoch
    if self.total_flos is not None:
        if self.args.local_rank != -1:
            total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
        else:
            total_flos = self.total_flos
        if total_flos > 0:
            # Report the value aggregated over processes, not only the local counter.
            logs["total_flos"] = total_flos
    if self.global_step is None:
        # when logging evaluation metrics without training
        self.global_step = 0
    if self.tb_writer:
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                self.tb_writer.add_scalar(k, v, self.global_step)
            else:
                logger.warning(
                    "Trainer is attempting to log a value of "
                    '"%s" of type %s for key "%s" as a scalar. '
                    "This invocation of Tensorboard's writer.add_scalar() "
                    "is incorrect so we dropped this attribute.",
                    v,
                    type(v),
                    k,
                )
        self.tb_writer.flush()
    if is_wandb_available():
        if self.is_world_process_zero():
            wandb.log(logs, step=self.global_step)
    if is_comet_available():
        if self.is_world_process_zero():
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
    output = {**logs, **{"step": self.global_step}}
    if self.is_world_process_zero():
        self.log_history.append(output)
    if iterator is not None:
        # tqdm.write forwards its argument to a text stream, so it needs a string.
        iterator.write(str(output))
    else:
        print(output)
|
|
|
|
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
|
"""
|
|
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
|
handling potential state.
|
|
"""
|
|
for k, v in inputs.items():
|
|
if isinstance(v, torch.Tensor):
|
|
inputs[k] = v.to(self.args.device)
|
|
|
|
if self.args.past_index >= 0 and self._past is not None:
|
|
inputs["mems"] = self._past
|
|
|
|
return inputs
|
|
|
|
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    """
    Run one forward/backward pass on a single batch.

    Subclass and override to inject custom behavior.

    Args:
        model (:obj:`nn.Module`):
            The model to train.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument :obj:`labels`. Check your model's documentation for all accepted arguments.

    Return:
        :obj:`torch.Tensor`: The detached training loss on this batch (already averaged over GPUs and divided by
        the number of gradient accumulation steps when those apply).
    """
    # Backward compatibility: defer to a legacy `_training_step` defined by a subclass.
    if hasattr(self, "_training_step"):
        warnings.warn(
            "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
            FutureWarning,
        )
        return self._training_step(model, inputs, self.optimizer)

    model.train()
    batch = self._prepare_inputs(inputs)
    native_amp = self.args.fp16 and _use_native_amp

    if native_amp:
        with autocast():
            batch_loss = self.compute_loss(model, batch)
    else:
        batch_loss = self.compute_loss(model, batch)

    # mean() to average on multi-gpu parallel training
    if self.args.n_gpu > 1:
        batch_loss = batch_loss.mean()

    accumulation_steps = self.args.gradient_accumulation_steps
    if accumulation_steps > 1:
        batch_loss = batch_loss / accumulation_steps

    if native_amp:
        self.scaler.scale(batch_loss).backward()
    elif self.args.fp16 and _use_apex:
        with amp.scale_loss(batch_loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        batch_loss.backward()

    return batch_loss.detach()
|
|
|
|
def compute_loss(self, model, inputs):
    """
    Compute the training loss: :obj:`inputs` is unpacked into :obj:`model` and the first element of the
    outputs is returned as the loss.

    When ``self.args.past_index >= 0``, the output at that index is kept in ``self._past``.

    Subclass and override for custom behavior.
    """
    model_outputs = model(**inputs)

    past_index = self.args.past_index
    if past_index >= 0:
        # Remember the past state so it can be reused later.
        self._past = model_outputs[past_index]

    # Positional access rather than `.loss`: the model may return a plain tuple instead of a ModelOutput.
    return model_outputs[0]
|
|
|
|
def is_local_master(self) -> bool:
    """
    Deprecated alias of :meth:`~transformers.Trainer.is_local_process_zero`.

    Emits a :obj:`FutureWarning` and returns the result of :meth:`~transformers.Trainer.is_local_process_zero`.

    .. warning::

        This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
    """
    deprecation_message = "This method is deprecated, use `Trainer.is_local_process_zero()` instead."
    warnings.warn(deprecation_message, FutureWarning)
    return self.is_local_process_zero()
|
|
|
|
def is_local_process_zero(self) -> bool:
    """
    Tell whether this process is the local main process (e.g., the main process of one machine when training
    in a distributed fashion on several machines).

    On TPU the answer comes from ``xm.is_master_ordinal(local=True)``; otherwise it is :obj:`True` when
    ``self.args.local_rank`` is ``-1`` or ``0``.
    """
    if not is_torch_tpu_available():
        return self.args.local_rank in [-1, 0]
    return xm.is_master_ordinal(local=True)
|
|
|
|
def is_world_master(self) -> bool:
    """
    Deprecated alias of :meth:`~transformers.Trainer.is_world_process_zero`.

    Emits a :obj:`FutureWarning` and returns the result of :meth:`~transformers.Trainer.is_world_process_zero`.

    .. warning::

        This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
    """
    deprecation_message = "This method is deprecated, use `Trainer.is_world_process_zero()` instead."
    warnings.warn(deprecation_message, FutureWarning)
    return self.is_world_process_zero()
|
|
|
|
def is_world_process_zero(self) -> bool:
    """
    Tell whether this process is the global main process (when training in a distributed fashion on several
    machines, this is only going to be :obj:`True` for one process).

    On TPU the answer comes from ``xm.is_master_ordinal(local=False)``; otherwise it is :obj:`True` when
    ``self.args.local_rank`` is ``-1`` or when the ``torch.distributed`` rank is ``0``.
    """
    if not is_torch_tpu_available():
        return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
    return xm.is_master_ordinal(local=False)
|
|
|
|
def save_model(self, output_dir: Optional[str] = None):
    """
    Save the model so that it can be reloaded with :obj:`from_pretrained()`.

    On TPU the saving is delegated to ``_save_tpu``; otherwise ``_save`` is only called from the world
    process zero.
    """
    if is_torch_tpu_available():
        self._save_tpu(output_dir)
        return

    if self.is_world_process_zero():
        self._save(output_dir)
|
|
|
|
def _save_tpu(self, output_dir: Optional[str] = None):
    """
    Save a checkpoint when running on TPU.

    The master ordinal writes the training arguments and the log history; the model (and the tokenizer, if
    any) are then saved with ``save_pretrained`` after an ``xm.rendezvous``.

    Args:
        output_dir (:obj:`str`, `optional`):
            Where to write the checkpoint. Defaults to ``self.args.output_dir``.

    Raises:
        ValueError: If ``self.model`` is not a :class:`~transformers.PreTrainedModel`.
    """
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    logger.info("Saving model checkpoint to %s", output_dir)

    if xm.is_master_ordinal():
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        # Use a context manager so the file is flushed and closed deterministically; the explicit
        # UTF-8 encoding is needed because `ensure_ascii=False` may emit non-ASCII characters.
        with open(os.path.join(output_dir, "log_history.json"), "w", encoding="utf-8") as log_file:
            json.dump(self.log_history, log_file, indent=2, ensure_ascii=False)

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    if not isinstance(self.model, PreTrainedModel):
        raise ValueError("Trainer.model appears to not be a PreTrainedModel")

    xm.rendezvous("saving_checkpoint")
    self._store_flos()
    self.model.save_pretrained(output_dir)
    if self.tokenizer is not None:
        self.tokenizer.save_pretrained(output_dir)
|
|
|
|
def _save(self, output_dir: Optional[str] = None):
    """
    Save a checkpoint: model, tokenizer (if any), training arguments and log history.

    Args:
        output_dir (:obj:`str`, `optional`):
            Where to write the checkpoint. Defaults to ``self.args.output_dir``. The directory is created
            if it does not exist.

    Raises:
        ValueError: If ``self.model`` is not a :class:`~transformers.PreTrainedModel`.
    """
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Saving model checkpoint to %s", output_dir)
    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    if not isinstance(self.model, PreTrainedModel):
        raise ValueError("Trainer.model appears to not be a PreTrainedModel")
    self._store_flos()
    self.model.save_pretrained(output_dir)
    if self.tokenizer is not None:
        self.tokenizer.save_pretrained(output_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
    # Use a context manager so the file is flushed and closed deterministically; the explicit
    # UTF-8 encoding is needed because `ensure_ascii=False` may emit non-ASCII characters.
    with open(os.path.join(output_dir, "log_history.json"), "w", encoding="utf-8") as log_file:
        json.dump(self.log_history, log_file, indent=2, ensure_ascii=False)
|
|
|
|
def _store_flos(self):
|
|
# Storing the number of floating-point operations that went into the model
|
|
if self.total_flos is not None:
|
|
if self.args.local_rank != -1:
|
|
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
|
else:
|
|
total_flos = self.total_flos
|
|
if total_flos > 0:
|
|
self.model.config.total_flos = total_flos
|
|
|
|
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
    """
    List the paths in ``self.args.output_dir`` matching ``{checkpoint_prefix}-*``, in ascending order.

    Args:
        checkpoint_prefix: Prefix of the checkpoint folder names.
        use_mtime: If true, order by modification time; otherwise order by the number that follows the
            prefix (paths without such a number are left out).

    Returns:
        The matching paths as strings, sorted by the chosen key (ties are broken by the path itself).
    """
    candidates = [str(entry) for entry in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

    keyed_paths = []
    for candidate in candidates:
        if use_mtime:
            keyed_paths.append((os.path.getmtime(candidate), candidate))
            continue
        step_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", candidate)
        if step_match and step_match.groups():
            keyed_paths.append((int(step_match.groups()[0]), candidate))

    keyed_paths.sort()
    return [candidate for _, candidate in keyed_paths]
|
|
|
|
def _rotate_checkpoints(self, use_mtime=False) -> None:
|
|
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
|
|
return
|
|
|
|
# Check if we should delete older checkpoint(s)
|
|
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
|
|
if len(checkpoints_sorted) <= self.args.save_total_limit:
|
|
return
|
|
|
|
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
|
|
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
|
for checkpoint in checkpoints_to_be_deleted:
|
|
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
|
shutil.rmtree(checkpoint)
|
|
|
|
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
    """
    Run evaluation and returns metrics.

    The calling script will be responsible for providing a method to compute metrics, as they are
    task-dependent (pass it to the init :obj:`compute_metrics` argument).

    You can also subclass and override this method to inject custom behavior.

    Args:
        eval_dataset (:obj:`Dataset`, `optional`):
            Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
            columns not accepted by the ``model.forward()`` method are automatically removed.

    Returns:
        A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
    """
    eval_dataloader = self.get_eval_dataloader(eval_dataset)

    output = self.prediction_loop(eval_dataloader, description="Evaluation")

    self.log(output.metrics)

    if self.args.tpu_metrics_debug or self.args.debug:
        if is_torch_tpu_available():
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        else:
            # Same guard as in the training loop: the XLA helpers are only usable on TPU.
            logger.warning(
                "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                "configured. Check your training configuration if this is unexpected."
            )

    return output.metrics
|
|
|
|
def predict(self, test_dataset: Dataset) -> PredictionOutput:
    """
    Run the prediction loop on :obj:`test_dataset` and return its output.

    Depending on the dataset and your use case, your test dataset may contain labels.
    In that case, this method will also return metrics, like in :obj:`evaluate()`.

    Args:
        test_dataset (:obj:`Dataset`):
            Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed.

    Returns:
        `NamedTuple`:
        predictions (:obj:`np.ndarray`):
            The predictions on :obj:`test_dataset`.
        label_ids (:obj:`np.ndarray`, `optional`):
            The labels (if the dataset contained some).
        metrics (:obj:`Dict[str, float]`, `optional`):
            The potential dictionary of metrics (if the dataset contained labels).
    """
    return self.prediction_loop(self.get_test_dataloader(test_dataset), description="Prediction")
|
|
|
|
def prediction_loop(
    self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
) -> PredictionOutput:
    """
    Loop over :obj:`dataloader` and collect losses, predictions and labels; used by both
    :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

    Works both with or without labels. When :obj:`prediction_loss_only` is left to :obj:`None`, the value of
    ``self.args.prediction_loss_only`` is used. All metric names in the result start with ``eval_``.
    """
    # Backward compatibility: defer to a legacy `_prediction_loop` defined by a subclass.
    if hasattr(self, "_prediction_loop"):
        warnings.warn(
            "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
            FutureWarning,
        )
        return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

    if prediction_loss_only is None:
        prediction_loss_only = self.args.prediction_loss_only

    assert not getattr(
        self.model.config, "output_attentions", False
    ), "The prediction loop does not work with `output_attentions=True`."
    assert not getattr(
        self.model.config, "output_hidden_states", False
    ), "The prediction loop does not work with `output_hidden_states=True`."

    # multi-gpu eval
    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
    model = torch.nn.DataParallel(self.model) if self.args.n_gpu > 1 else self.model

    batch_size = dataloader.batch_size
    logger.info("***** Running %s *****", description)
    logger.info(" Num examples = %d", self.num_examples(dataloader))
    logger.info(" Batch size = %d", batch_size)
    eval_losses: List[float] = []
    preds: torch.Tensor = None
    label_ids: torch.Tensor = None
    model.eval()

    if is_torch_tpu_available():
        dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

    if self.args.past_index >= 0:
        self._past = None

    disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
    for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
        step_loss, step_logits, step_labels = self.prediction_step(model, inputs, prediction_loss_only)
        # The size of this batch is read from the first entry of the inputs.
        first_key = list(inputs.keys())[0]
        batch_size = inputs[first_key].shape[0]
        if step_loss is not None:
            # One copy of the batch loss per example.
            eval_losses.extend([step_loss] * batch_size)
        if step_logits is not None:
            preds = step_logits if preds is None else nested_concat(preds, step_logits, dim=0)
        if step_labels is not None:
            label_ids = step_labels if label_ids is None else nested_concat(label_ids, step_labels, dim=0)

    if self.args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of the evaluation loop
        delattr(self, "_past")

    if self.args.local_rank != -1:
        # In distributed mode, concatenate all results from all nodes:
        if preds is not None:
            preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
        if label_ids is not None:
            label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
    elif is_torch_tpu_available():
        # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
        if preds is not None:
            preds = nested_xla_mesh_reduce("eval_preds", preds)
        if label_ids is not None:
            label_ids = nested_xla_mesh_reduce("eval_label_ids", label_ids, torch.cat)
        if eval_losses is not None:
            eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

    # Finally, turn the aggregated tensors into numpy arrays.
    if preds is not None:
        preds = nested_numpify(preds)
    if label_ids is not None:
        label_ids = nested_numpify(label_ids)

    metrics_computable = self.compute_metrics is not None and preds is not None and label_ids is not None
    if metrics_computable:
        metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
    else:
        metrics = {}

    if len(eval_losses) > 0:
        if self.args.local_rank == -1:
            metrics["eval_loss"] = np.mean(eval_losses)
        else:
            gathered_losses = distributed_broadcast_scalars(
                eval_losses, num_total_examples=self.num_examples(dataloader)
            )
            metrics["eval_loss"] = gathered_losses.mean().item()

    # Prefix all keys with eval_
    for name in list(metrics.keys()):
        if not name.startswith("eval_"):
            metrics[f"eval_{name}"] = metrics.pop(name)

    return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
|
|
|
def prediction_step(
    self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Run :obj:`model` on :obj:`inputs` without tracking gradients and return the loss, logits and labels.

    Subclass and override to inject custom behavior.

    Args:
        model (:obj:`nn.Module`):
            The model to evaluate.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument :obj:`labels`. Check your model's documentation for all accepted arguments.
        prediction_loss_only (:obj:`bool`):
            Whether or not to return the loss only.

    Return:
        Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        A tuple with the loss, logits and labels (each being optional). The loss is only computed when every
        name in ``self.args.label_names`` has a non-:obj:`None` entry in :obj:`inputs`.
    """
    label_names = self.args.label_names
    has_labels = not any(inputs.get(name) is None for name in label_names)
    batch = self._prepare_inputs(inputs)

    with torch.no_grad():
        outputs = model(**batch)
        # With labels, the first output is the loss (the .mean() reduces it in case of distributed
        # training); without labels every output is kept. Slicing gives a tuple even if `outputs`
        # is a `ModelOutput`.
        loss = outputs[0].mean().item() if has_labels else None
        raw_logits = outputs[1:] if has_labels else outputs[:]
        if self.args.past_index >= 0:
            # Without labels there is no loss in front, so the past state sits one position earlier.
            shift = 0 if has_labels else 1
            self._past = outputs[self.args.past_index - shift]

    if prediction_loss_only:
        return (loss, None, None)

    detached_logits = tuple(tensor.detach() for tensor in raw_logits)
    logits = detached_logits[0] if len(detached_logits) == 1 else detached_logits

    labels = None
    if has_labels:
        detached_labels = tuple(batch.get(name).detach() for name in label_names)
        labels = detached_labels[0] if len(detached_labels) == 1 else detached_labels

    return (loss, logits, labels)
|
|
|
|
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
    """
    Return the number of floating-point operations reported by the model for :obj:`inputs`.

    The model's own ``floating_point_ops`` method is used when it has one (after unwrapping
    ``DataParallel`` / ``DistributedDataParallel``). If using a model without such a method, either
    implement it in the model or subclass and override this method.

    Args:
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

    Returns:
        The value returned by the model's ``floating_point_ops`` method, or ``0`` when the model has none.
    """
    parallel_wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    model = self.model.module if isinstance(self.model, parallel_wrappers) else self.model

    if not hasattr(model, "floating_point_ops"):
        return 0
    return model.floating_point_ops(inputs)
|