* gpt2 and t5 parallel modeling * model_parallel utils update * adding missing model_parallel_utils Adds missing model_parallel_utils and reverses the changes to code in modeling_gpt2 and modeling_t5 * training_args reformat Reformatted training_args * style formatting Style formatting doc string length on training_args and model_parallel_utils * style changes make style && make quality for training_args and model_parallel_utils. * adding tests * minor change in trainer reverts loss calculation * Update training_args.py * Update training_args.py added back docstring language for adam_beta1 and adam_beta2 * Update trainer.py * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix style & rebase Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
1549 lines
71 KiB
Python
Executable File
1549 lines
71 KiB
Python
Executable File
# coding=utf-8
|
|
# Copyright 2020-present the HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
|
|
"""
|
|
|
|
import collections
|
|
import inspect
|
|
import math
|
|
import os
|
|
import re
|
|
import shutil
|
|
import warnings
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
# Integrations must be imported before ML frameworks:
|
|
from .integrations import ( # isort: split
|
|
default_hp_search_backend,
|
|
hp_params,
|
|
is_azureml_available,
|
|
is_comet_available,
|
|
is_mlflow_available,
|
|
is_optuna_available,
|
|
is_ray_available,
|
|
is_tensorboard_available,
|
|
is_wandb_available,
|
|
run_hp_search_optuna,
|
|
run_hp_search_ray,
|
|
)
|
|
|
|
import numpy as np
|
|
import torch
|
|
from packaging import version
|
|
from torch import nn
|
|
from torch.utils.data.dataloader import DataLoader
|
|
from torch.utils.data.dataset import Dataset
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
|
|
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
|
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
|
from .modeling_utils import PreTrainedModel
|
|
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
|
from .optimization import AdamW, get_linear_schedule_with_warmup
|
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
|
from .trainer_callback import (
|
|
CallbackHandler,
|
|
DefaultFlowCallback,
|
|
PrinterCallback,
|
|
ProgressCallback,
|
|
TrainerCallback,
|
|
TrainerControl,
|
|
TrainerState,
|
|
)
|
|
from .trainer_pt_utils import (
|
|
DistributedTensorGatherer,
|
|
SequentialDistributedSampler,
|
|
distributed_broadcast_scalars,
|
|
distributed_concat,
|
|
get_tpu_sampler,
|
|
nested_concat,
|
|
nested_detach,
|
|
nested_numpify,
|
|
nested_xla_mesh_reduce,
|
|
reissue_pt_warnings,
|
|
)
|
|
from .trainer_utils import (
|
|
PREFIX_CHECKPOINT_DIR,
|
|
BestRun,
|
|
EvalPrediction,
|
|
HPSearchBackend,
|
|
PredictionOutput,
|
|
TrainOutput,
|
|
default_compute_objective,
|
|
default_hp_space,
|
|
set_seed,
|
|
)
|
|
from .training_args import TrainingArguments
|
|
from .utils import logging
|
|
|
|
|
|
# Mixed-precision backend flags, filled in by the torch-version check below: `_use_native_amp` selects
# `torch.cuda.amp` (torch >= 1.6), `_use_apex` selects NVIDIA apex (older torch).
_use_native_amp = False
_use_apex = False

# Callbacks every Trainer starts with. Each optional integration below appends its own callback to this list
# when the corresponding library is available.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
# Progress-reporting callback; replaced by the notebook flavour when running inside a notebook.
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
    from .file_utils import is_apex_available

    if is_apex_available():
        from apex import amp
    # NOTE(review): set even when apex is missing; `train()` re-checks `is_apex_available()` and raises an
    # ImportError with install instructions in that case.
    _use_apex = True
else:
    _use_native_amp = True
    from torch.cuda.amp import autocast

# The training loop only wraps non-boundary gradient-accumulation steps in `model.no_sync()` when this is True
# (presumably `DistributedDataParallel.no_sync` is not available before torch 1.2 — confirm).
if version.parse(torch.__version__) < version.parse("1.2"):
    _use_ddp_no_sync = False
else:
    _use_ddp_no_sync = True

# Optional third-party backends, imported only when installed.
if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

# Logging integrations: each one registers its callback as a default when its library is available.
if is_tensorboard_available():
    from .integrations import TensorBoardCallback

    DEFAULT_CALLBACKS.append(TensorBoardCallback)


if is_wandb_available():
    from .integrations import WandbCallback

    DEFAULT_CALLBACKS.append(WandbCallback)

if is_comet_available():
    from .integrations import CometCallback

    DEFAULT_CALLBACKS.append(CometCallback)

if is_mlflow_available():
    from .integrations import MLflowCallback

    DEFAULT_CALLBACKS.append(MLflowCallback)

# Hyperparameter-search backends.
if is_optuna_available():
    import optuna

if is_ray_available():
    from ray import tune

if is_azureml_available():
    from .integrations import AzureMLCallback

    DEFAULT_CALLBACKS.append(AzureMLCallback)

logger = logging.get_logger(__name__)
|
|
|
|
|
|
class Trainer:
|
|
"""
|
|
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
|
|
|
|
Args:
|
|
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
|
|
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
|
|
|
|
.. note::
|
|
|
|
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
|
|
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
|
|
they work the same way as the 🤗 Transformers models.
|
|
args (:class:`~transformers.TrainingArguments`, `optional`):
|
|
The arguments to tweak for training. Will default to a basic instance of
|
|
:class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in
|
|
the current directory if not provided.
|
|
data_collator (:obj:`DataCollator`, `optional`):
|
|
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
|
|
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
|
|
:func:`~transformers.DataCollatorWithPadding` otherwise.
|
|
train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
|
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
|
``model.forward()`` method are automatically removed.
|
|
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
|
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
|
``model.forward()`` method are automatically removed.
|
|
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
|
|
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
|
|
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
|
|
interrupted training or reuse the fine-tuned model.
|
|
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
|
|
A function that instantiates the model to be used. If provided, each call to
|
|
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
|
|
|
|
The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be
|
|
able to choose different architectures according to hyper parameters (such as layer count, sizes of inner
|
|
layers, dropout probabilities etc).
|
|
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
|
The function that will be used to compute metrics at evaluation. Must take a
|
|
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
|
callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
|
|
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
|
|
detailed in :doc:`here <callback>`.
|
|
|
|
If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
|
|
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`): A tuple
|
|
containing the optimizer and the scheduler to use. Will default to an instance of
|
|
:class:`~transformers.AdamW` on your model and a scheduler given by
|
|
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
|
|
"""
|
|
|
|
def __init__(
    self,
    model: Union[PreTrainedModel, torch.nn.Module] = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Dataset] = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    model_init: Callable[[], PreTrainedModel] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
):
    """
    Set up the trainer. Arguments are described in the :class:`Trainer` class docstring.

    Raises:
        :obj:`RuntimeError`: if both ``model_init`` and an optimizer or scheduler in ``optimizers`` are given.
        :obj:`ValueError`: if ``data_collator`` is an old-style object exposing only ``collate_batch``, if
            ``train_dataset`` has no ``__len__`` while ``args.max_steps <= 0``, or if ``eval_dataset`` has no
            ``__len__``.
    """
    # Fall back to default arguments writing to ./tmp_trainer.
    if args is None:
        logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
        args = TrainingArguments("tmp_trainer")
    self.args = args
    # Seed must be set before instantiating the model when using `model_init`.
    set_seed(self.args.seed)
    assert (
        model is not None or model_init is not None
    ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
    self.model_init = model_init
    self.hp_name = None
    if model is None and model_init is not None:
        model = self.call_model_init()

    # Move the model to `args.device`, except with model parallelism where the model is used as given
    # (presumably its layers have already been placed on their devices by the caller).
    if model is not None and not self.args.model_parallel:
        model = model.to(args.device)
    self.model = model

    default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
    self.data_collator = data_collator if data_collator is not None else default_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.tokenizer = tokenizer

    self.compute_metrics = compute_metrics
    self.optimizer, self.lr_scheduler = optimizers
    # With `model_init`, `train()` builds a new model and resets the optimizer/scheduler, so user-provided ones
    # could not be kept.
    if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
        raise RuntimeError(
            "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )
    callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
    self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

    # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
    self._loggers_initialized = False

    # Create output directory if needed
    if self.is_world_process_zero():
        os.makedirs(self.args.output_dir, exist_ok=True)
    if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
        # Set an xla_device flag on the model's config.
        # We'll find a more elegant way and not need to do this in the future.
        self.model.config.xla_device = True
    # Reject old-style collators that only expose `collate_batch` instead of being callable.
    if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
        raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

    if args.max_steps > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")

    # Enforce rules on using datasets with no __len__
    if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
        raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
    if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
        raise ValueError("eval_dataset must implement __len__")

    if is_datasets_available():
        if isinstance(train_dataset, datasets.Dataset):
            self._remove_unused_columns(self.train_dataset, description="training")
        if isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(self.eval_dataset, description="evaluation")

    self.state = TrainerState()
    self.control = TrainerControl()
    # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
    # state at each call to self.log.
    self._total_flos = None
    if self.args.fp16 and _use_native_amp:
        self.scaler = torch.cuda.amp.GradScaler()
    self.hp_search_backend = None
    self.use_tune_checkpoints = False
    # Question-answering models take `start_positions`/`end_positions` as labels instead of `labels`.
    default_label_names = (
        ["start_positions", "end_positions"]
        if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
        else ["labels"]
    )
    self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
|
|
|
def add_callback(self, callback):
|
|
"""
|
|
Add a callback to the current list of :class:`~transformer.TrainerCallback`.
|
|
|
|
Args:
|
|
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
|
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
|
In the first case, will instantiate a member of that class.
|
|
"""
|
|
self.callback_handler.add_callback(callback)
|
|
|
|
def pop_callback(self, callback):
|
|
"""
|
|
Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.
|
|
|
|
If the callback is not found, returns :obj:`None` (and no error is raised).
|
|
|
|
Args:
|
|
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
|
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
|
In the first case, will pop the first member of that class found in the list of callbacks.
|
|
|
|
Returns:
|
|
:class:`~transformer.TrainerCallback`: The callback removed, if found.
|
|
"""
|
|
return self.callback_handler.pop_callback(callback)
|
|
|
|
def remove_callback(self, callback):
|
|
"""
|
|
Remove a callback from the current list of :class:`~transformer.TrainerCallback`.
|
|
|
|
Args:
|
|
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
|
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
|
In the first case, will remove the first member of that class found in the list of callbacks.
|
|
"""
|
|
self.callback_handler.remove_callback(callback)
|
|
|
|
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
|
if not self.args.remove_unused_columns:
|
|
return
|
|
# Inspect model forward signature to keep only the arguments it accepts.
|
|
signature = inspect.signature(self.model.forward)
|
|
signature_columns = list(signature.parameters.keys())
|
|
# Labels may be named label or label_ids, the default data collator handles that.
|
|
signature_columns += ["label", "label_ids"]
|
|
columns = [k for k in signature_columns if k in dataset.column_names]
|
|
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
|
|
dset_description = "" if description is None else f"in the {description} set "
|
|
logger.info(
|
|
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
|
)
|
|
dataset.set_format(type=dataset.format["type"], columns=columns)
|
|
|
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
|
|
self.train_dataset, collections.abc.Sized
|
|
):
|
|
return None
|
|
elif is_torch_tpu_available():
|
|
return get_tpu_sampler(self.train_dataset)
|
|
else:
|
|
return (
|
|
RandomSampler(self.train_dataset)
|
|
if self.args.local_rank == -1
|
|
else DistributedSampler(self.train_dataset)
|
|
)
|
|
|
|
def get_train_dataloader(self) -> DataLoader:
|
|
"""
|
|
Returns the training :class:`~torch.utils.data.DataLoader`.
|
|
|
|
Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted
|
|
to distributed training if necessary) otherwise.
|
|
|
|
Subclass and override this method if you want to inject some custom behavior.
|
|
"""
|
|
if self.train_dataset is None:
|
|
raise ValueError("Trainer: training requires a train_dataset.")
|
|
train_sampler = self._get_train_sampler()
|
|
|
|
return DataLoader(
|
|
self.train_dataset,
|
|
batch_size=self.args.train_batch_size,
|
|
sampler=train_sampler,
|
|
collate_fn=self.data_collator,
|
|
drop_last=self.args.dataloader_drop_last,
|
|
num_workers=self.args.dataloader_num_workers,
|
|
)
|
|
|
|
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
    """
    Choose an order-preserving sampler for evaluation and prediction.

    When a TPU is available the dataset is split with :class:`SequentialDistributedSampler` using the XLA world
    size and ordinal. In distributed mode (``local_rank != -1``) the same sampler is used with its defaults.
    Otherwise a plain :class:`SequentialSampler` is returned.
    """
    if not is_torch_tpu_available():
        if self.args.local_rank == -1:
            return SequentialSampler(eval_dataset)
        return SequentialDistributedSampler(eval_dataset)
    return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
|
|
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
|
"""
|
|
Returns the evaluation :class:`~torch.utils.data.DataLoader`.
|
|
|
|
Subclass and override this method if you want to inject some custom behavior.
|
|
|
|
Args:
|
|
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
|
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
|
|
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
|
"""
|
|
if eval_dataset is None and self.eval_dataset is None:
|
|
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
|
elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
|
raise ValueError("eval_dataset must implement __len__")
|
|
elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
|
self._remove_unused_columns(eval_dataset, description="evaluation")
|
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
|
|
|
return DataLoader(
|
|
eval_dataset,
|
|
sampler=eval_sampler,
|
|
batch_size=self.args.eval_batch_size,
|
|
collate_fn=self.data_collator,
|
|
drop_last=self.args.dataloader_drop_last,
|
|
num_workers=self.args.dataloader_num_workers,
|
|
)
|
|
|
|
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
|
"""
|
|
Returns the test :class:`~torch.utils.data.DataLoader`.
|
|
|
|
Subclass and override this method if you want to inject some custom behavior.
|
|
|
|
Args:
|
|
test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
|
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
|
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
|
"""
|
|
if not isinstance(test_dataset, collections.abc.Sized):
|
|
raise ValueError("test_dataset must implement __len__")
|
|
elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
|
self._remove_unused_columns(test_dataset, description="test")
|
|
test_sampler = self._get_eval_sampler(test_dataset)
|
|
|
|
# We use the same batch_size as for eval.
|
|
return DataLoader(
|
|
test_dataset,
|
|
sampler=test_sampler,
|
|
batch_size=self.args.eval_batch_size,
|
|
collate_fn=self.data_collator,
|
|
drop_last=self.args.dataloader_drop_last,
|
|
)
|
|
|
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
|
"""
|
|
Setup the optimizer and the learning rate scheduler.
|
|
|
|
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
|
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
|
|
"""
|
|
if self.optimizer is None:
|
|
no_decay = ["bias", "LayerNorm.weight"]
|
|
optimizer_grouped_parameters = [
|
|
{
|
|
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
|
"weight_decay": self.args.weight_decay,
|
|
},
|
|
{
|
|
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
|
"weight_decay": 0.0,
|
|
},
|
|
]
|
|
self.optimizer = AdamW(
|
|
optimizer_grouped_parameters,
|
|
lr=self.args.learning_rate,
|
|
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
|
eps=self.args.adam_epsilon,
|
|
)
|
|
if self.lr_scheduler is None:
|
|
self.lr_scheduler = get_linear_schedule_with_warmup(
|
|
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
|
)
|
|
|
|
def num_examples(self, dataloader: DataLoader) -> int:
|
|
"""
|
|
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
|
|
|
Will raise an exception if the underlying dataset dese not implement method :obj:`__len__`
|
|
"""
|
|
return len(dataloader.dataset)
|
|
|
|
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
|
""" HP search setup code """
|
|
self._trial = trial
|
|
|
|
if self.hp_search_backend is None or trial is None:
|
|
return
|
|
|
|
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
|
for key, value in params.items():
|
|
if not hasattr(self.args, key):
|
|
raise AttributeError(
|
|
f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
|
|
)
|
|
old_attr = getattr(self.args, key, None)
|
|
# Casting value to the proper type
|
|
if old_attr is not None:
|
|
value = type(old_attr)(value)
|
|
setattr(self.args, key, value)
|
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
|
logger.info("Trial:", trial.params)
|
|
|
|
def _report_to_hp_search(
|
|
self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
|
|
):
|
|
if self.hp_search_backend is None or trial is None:
|
|
return
|
|
self.objective = self.compute_objective(metrics.copy())
|
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
|
trial.report(self.objective, epoch)
|
|
if trial.should_prune():
|
|
raise optuna.TrialPruned()
|
|
elif self.hp_search_backend == HPSearchBackend.RAY:
|
|
if self.state.global_step % self.args.save_steps == 0:
|
|
self._tune_save_checkpoint()
|
|
tune.report(objective=self.objective, **metrics)
|
|
|
|
def _tune_save_checkpoint(self):
|
|
if not self.use_tune_checkpoints:
|
|
return
|
|
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
|
|
self.args.output_dir = checkpoint_dir
|
|
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
|
self.save_model(output_dir)
|
|
if self.is_world_master():
|
|
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
|
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
|
|
|
def call_model_init(self, trial=None):
|
|
model_init_argcount = len(inspect.signature(self.model_init).parameters)
|
|
if model_init_argcount == 0:
|
|
model = self.model_init()
|
|
elif model_init_argcount == 1:
|
|
model = self.model_init(trial)
|
|
else:
|
|
raise RuntimeError("model_init should have 0 or 1 argument.")
|
|
|
|
if model is None:
|
|
raise RuntimeError("model_init should not return None.")
|
|
|
|
return model
|
|
|
|
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
|
"""
|
|
Main training entry point.
|
|
|
|
Args:
|
|
model_path (:obj:`str`, `optional`):
|
|
Local path to the model if the model to train has been instantiated from a local path. If present,
|
|
training will resume from the optimizer/scheduler states loaded here.
|
|
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
|
The trial run or the hyperparameter dictionary for hyperparameter search.
|
|
"""
|
|
# This might change the seed so needs to run first.
|
|
self._hp_search_setup(trial)
|
|
|
|
# Model re-init
|
|
if self.model_init is not None:
|
|
# Seed must be set before instantiating the model when using model_init.
|
|
set_seed(self.args.seed)
|
|
|
|
model = self.call_model_init(trial)
|
|
|
|
if not self.args.model_parallel:
|
|
self.model = model.to(self.args.device)
|
|
|
|
# Reinitializes optimizer and scheduler
|
|
self.optimizer, self.lr_scheduler = None, None
|
|
|
|
# Keeping track whether we can can len() on the dataset or not
|
|
train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
|
|
|
|
# Data loader and number of training steps
|
|
train_dataloader = self.get_train_dataloader()
|
|
|
|
# Setting up training control variables:
|
|
# number of training epochs: num_train_epochs
|
|
# number of training steps per epoch: num_update_steps_per_epoch
|
|
# total number of training steps to execute: max_steps
|
|
if train_dataset_is_sized:
|
|
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
|
|
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
|
if self.args.max_steps > 0:
|
|
max_steps = self.args.max_steps
|
|
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
|
|
self.args.max_steps % num_update_steps_per_epoch > 0
|
|
)
|
|
else:
|
|
max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
|
|
num_train_epochs = math.ceil(self.args.num_train_epochs)
|
|
else:
|
|
# see __init__. max_steps is set when the dataset has no __len__
|
|
max_steps = self.args.max_steps
|
|
num_train_epochs = 1
|
|
num_update_steps_per_epoch = max_steps
|
|
|
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
|
self.state = TrainerState()
|
|
self.state.is_hyper_param_search = trial is not None
|
|
|
|
# Check if saved optimizer or scheduler states exist
|
|
self._load_optimizer_and_scheduler(model_path)
|
|
|
|
# Mixed precision training with apex (torch < 1.6)
|
|
model = self.model
|
|
if self.args.fp16 and _use_apex:
|
|
if not is_apex_available():
|
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
|
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
|
|
|
# Multi-gpu training (should be after apex fp16 initialization)
|
|
if self.args.n_gpu > 1 and not self.args.model_parallel:
|
|
model = torch.nn.DataParallel(model)
|
|
|
|
# Distributed training (should be after apex fp16 initialization)
|
|
if self.args.local_rank != -1:
|
|
model = torch.nn.parallel.DistributedDataParallel(
|
|
model,
|
|
device_ids=[self.args.local_rank],
|
|
output_device=self.args.local_rank,
|
|
find_unused_parameters=(
|
|
not getattr(model.config, "gradient_checkpointing", False)
|
|
if isinstance(model, PreTrainedModel)
|
|
else True
|
|
),
|
|
)
|
|
# find_unused_parameters breaks checkpointing as per
|
|
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
|
|
|
# Train!
|
|
if is_torch_tpu_available():
|
|
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
|
else:
|
|
total_train_batch_size = (
|
|
self.args.train_batch_size
|
|
* self.args.gradient_accumulation_steps
|
|
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
|
|
)
|
|
|
|
num_examples = (
|
|
self.num_examples(train_dataloader)
|
|
if train_dataset_is_sized
|
|
else total_train_batch_size * self.args.max_steps
|
|
)
|
|
|
|
logger.info("***** Running training *****")
|
|
logger.info(" Num examples = %d", num_examples)
|
|
logger.info(" Num Epochs = %d", num_train_epochs)
|
|
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
|
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
|
|
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
|
logger.info(" Total optimization steps = %d", max_steps)
|
|
|
|
self.state.epoch = 0
|
|
epochs_trained = 0
|
|
steps_trained_in_current_epoch = 0
|
|
|
|
# Check if continuing training from a checkpoint
|
|
if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")):
|
|
self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json"))
|
|
epochs_trained = self.state.global_step // num_update_steps_per_epoch
|
|
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
|
|
steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
|
|
|
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
|
logger.info(" Continuing training from epoch %d", epochs_trained)
|
|
logger.info(" Continuing training from global step %d", self.state.global_step)
|
|
logger.info(" Will skip the first %d batches in the first epoch", steps_trained_in_current_epoch)
|
|
|
|
# Update the references
|
|
self.callback_handler.model = self.model
|
|
self.callback_handler.optimizer = self.optimizer
|
|
self.callback_handler.lr_scheduler = self.lr_scheduler
|
|
self.callback_handler.train_dataloader = train_dataloader
|
|
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
|
|
self.state.trial_params = hp_params(trial) if trial is not None else None
|
|
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
|
# to set this after the load.
|
|
self.state.max_steps = max_steps
|
|
self.state.num_train_epochs = num_train_epochs
|
|
self.state.is_local_process_zero = self.is_local_process_zero()
|
|
self.state.is_world_process_zero = self.is_world_process_zero()
|
|
|
|
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
|
|
tr_loss = torch.tensor(0.0).to(self.args.device)
|
|
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
|
|
self._total_loss_scalar = 0.0
|
|
self._globalstep_last_logged = 0
|
|
self._total_flos = self.state.total_flos
|
|
model.zero_grad()
|
|
|
|
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
|
|
|
|
for epoch in range(epochs_trained, num_train_epochs):
|
|
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
|
train_dataloader.sampler.set_epoch(epoch)
|
|
|
|
if is_torch_tpu_available():
|
|
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
|
self.args.device
|
|
)
|
|
epoch_iterator = parallel_loader
|
|
else:
|
|
epoch_iterator = train_dataloader
|
|
|
|
# Reset the past mems state at the beginning of each epoch if necessary.
|
|
if self.args.past_index >= 0:
|
|
self._past = None
|
|
|
|
steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
|
|
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
|
|
|
|
for step, inputs in enumerate(epoch_iterator):
|
|
|
|
# Skip past any already trained steps if resuming training
|
|
if steps_trained_in_current_epoch > 0:
|
|
steps_trained_in_current_epoch -= 1
|
|
continue
|
|
|
|
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
|
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
|
|
|
if (
|
|
((step + 1) % self.args.gradient_accumulation_steps != 0)
|
|
and self.args.local_rank != -1
|
|
and _use_ddp_no_sync
|
|
):
|
|
with model.no_sync():
|
|
tr_loss += self.training_step(model, inputs)
|
|
else:
|
|
tr_loss += self.training_step(model, inputs)
|
|
self._total_flos += self.floating_point_ops(inputs)
|
|
|
|
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
|
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
|
steps_in_epoch <= self.args.gradient_accumulation_steps
|
|
and (step + 1) == steps_in_epoch
|
|
):
|
|
if self.args.fp16 and _use_native_amp:
|
|
self.scaler.unscale_(self.optimizer)
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
|
elif self.args.fp16 and _use_apex:
|
|
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
|
|
else:
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
|
|
|
if is_torch_tpu_available():
|
|
xm.optimizer_step(self.optimizer)
|
|
elif self.args.fp16 and _use_native_amp:
|
|
self.scaler.step(self.optimizer)
|
|
self.scaler.update()
|
|
else:
|
|
self.optimizer.step()
|
|
|
|
self.lr_scheduler.step()
|
|
model.zero_grad()
|
|
self.state.global_step += 1
|
|
self.state.epoch = epoch + (step + 1) / steps_in_epoch
|
|
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
|
|
|
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
|
|
|
if self.control.should_epoch_stop or self.control.should_training_stop:
|
|
break
|
|
|
|
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
|
|
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
|
|
|
|
if self.args.tpu_metrics_debug or self.args.debug:
|
|
if is_torch_tpu_available():
|
|
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
|
xm.master_print(met.metrics_report())
|
|
else:
|
|
logger.warning(
|
|
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
|
"configured. Check your training configuration if this is unexpected."
|
|
)
|
|
if self.control.should_training_stop:
|
|
break
|
|
|
|
if self.args.past_index and hasattr(self, "_past"):
|
|
# Clean the state at the end of training
|
|
delattr(self, "_past")
|
|
|
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
|
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
|
|
logger.info(
|
|
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
|
|
)
|
|
if isinstance(model, PreTrainedModel):
|
|
self.model = model.from_pretrained(self.state.best_model_checkpoint)
|
|
if not self.args.model_parallel:
|
|
self.model = self.model.to(self.args.device)
|
|
else:
|
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
|
self.model.load_state_dict(state_dict)
|
|
|
|
if self._total_flos is not None:
|
|
self.store_flos()
|
|
self.log({"total_flos": self.state.total_flos})
|
|
|
|
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
|
# add remaining tr_loss
|
|
self._total_loss_scalar += tr_loss.item()
|
|
|
|
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
|
|
|
|
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
    """
    Perform whichever of logging, evaluation and checkpointing :obj:`self.control` currently asks for.

    Args:
        tr_loss (:obj:`torch.Tensor`):
            Running sum of the training losses since the last log entry. When a log entry is emitted its value is
            added to ``self._total_loss_scalar`` and the tensor is zeroed in place, so the caller's tensor restarts
            from zero.
        model (:obj:`nn.Module`):
            The model being trained, forwarded to :meth:`_save_checkpoint`.
        trial:
            The current hyperparameter-search trial (or :obj:`None`), forwarded to the hp-search report and the
            checkpoint saver.
        epoch:
            The current epoch, forwarded to :meth:`_report_to_hp_search`.
    """
    if self.control.should_log:
        accumulated_loss = tr_loss.item()
        # Zero the caller's tensor in place (not rebinding) so the training loop keeps accumulating into it.
        tr_loss -= tr_loss

        mean_loss = accumulated_loss / (self.state.global_step - self._globalstep_last_logged)
        # Schedulers only expose get_last_lr() starting with PyTorch 1.4; older ones use get_lr().
        if version.parse(torch.__version__) >= version.parse("1.4"):
            current_lr = self.lr_scheduler.get_last_lr()[0]
        else:
            current_lr = self.lr_scheduler.get_lr()[0]
        logs: Dict[str, float] = {"loss": mean_loss, "learning_rate": current_lr}

        self._total_loss_scalar += accumulated_loss
        self._globalstep_last_logged = self.state.global_step
        self.log(logs)

    metrics = None
    if self.control.should_evaluate:
        metrics = self.evaluate()
        self._report_to_hp_search(trial, epoch, metrics)

    if not self.control.should_save:
        return
    self._save_checkpoint(model, trial, metrics=metrics)
    self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
|
|
|
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Save a full training checkpoint for the current global step.

    The model, optimizer state, scheduler state and trainer state are written to
    ``{args.output_dir}/{PREFIX_CHECKPOINT_DIR}-{global_step}``. During a hyperparameter search the folder is
    nested inside a per-run folder. When evaluation ``metrics`` are given and ``args.metric_for_best_model`` is
    set, the best metric / best checkpoint are updated. Finally, older checkpoints are rotated out.

    Args:
        model (:obj:`nn.Module`):
            The model being trained; either :obj:`self.model` itself or a wrapper whose ``.module`` is
            :obj:`self.model`.
        trial:
            The current hyperparameter-search trial, or :obj:`None`.
        metrics (:obj:`Dict[str, float]`, `optional`):
            The metrics returned by the last evaluation, if any.

    Raises:
        KeyError: if ``args.metric_for_best_model`` (prefixed with ``eval_`` if needed) is not in ``metrics``.
    """
    # In all cases (even distributed/parallel), self.model is always a reference
    # to the model we want to save.
    if hasattr(model, "module"):
        assert model.module is self.model, f"Module {model.module} should be a reference to self.model"
    else:
        assert model is self.model, f"Model {model} should be a reference to self.model"

    # Save model checkpoint
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    if self.hp_search_backend is not None and trial is not None:
        run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
        run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
        output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
    else:
        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

    self.store_flos()
    self.save_model(output_dir)

    # Save optimizer and scheduler. The captured warnings are handed to reissue_pt_warnings only after leaving
    # catch_warnings(record=True), so they are not recorded (and swallowed) a second time.
    if is_torch_tpu_available():
        xm.rendezvous("saving_optimizer_states")
        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        reissue_pt_warnings(caught_warnings)
    elif self.is_world_process_zero():
        torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
        reissue_pt_warnings(caught_warnings)

    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        try:
            metric_value = metrics[metric_to_check]
        except KeyError as exc:
            # A bare KeyError here is hard to diagnose; point the user at the offending training argument.
            raise KeyError(
                f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found "
                f"in the evaluation metrics. The available evaluation metrics are: {list(metrics.keys())}."
            ) from exc

        operator = np.greater if self.args.greater_is_better else np.less
        if (
            self.state.best_metric is None
            or self.state.best_model_checkpoint is None
            or operator(metric_value, self.state.best_metric)
        ):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

    # Save the Trainer state
    if self.is_world_process_zero():
        self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

    # Maybe delete some older checkpoints.
    if self.is_world_process_zero():
        self._rotate_checkpoints(use_mtime=True)
|
|
|
|
def _load_optimizer_and_scheduler(self, model_path):
    """
    If optimizer and scheduler states exist, load them.

    Args:
        model_path (:obj:`str`, `optional`):
            Checkpoint folder expected to contain ``optimizer.pt`` and ``scheduler.pt``. Nothing is loaded when it
            is :obj:`None` or when either file is missing.
    """
    if model_path is None:
        return
    optimizer_file = os.path.join(model_path, "optimizer.pt")
    scheduler_file = os.path.join(model_path, "scheduler.pt")
    if not (os.path.isfile(optimizer_file) and os.path.isfile(scheduler_file)):
        return

    # Load in optimizer and scheduler states
    if is_torch_tpu_available():
        # On TPU we have to take some extra precautions to properly load the states on the right device:
        # load them on CPU first, then transfer them.
        optimizer_state = torch.load(optimizer_file, map_location="cpu")
        with warnings.catch_warnings(record=True) as caught_warnings:
            lr_scheduler_state = torch.load(scheduler_file, map_location="cpu")
        reissue_pt_warnings(caught_warnings)

        # send_cpu_data_to_device returns the transferred structure instead of moving tensors in place, so the
        # results must be kept; discarding them made the transfer a no-op.
        optimizer_state = xm.send_cpu_data_to_device(optimizer_state, self.args.device)
        lr_scheduler_state = xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

        self.optimizer.load_state_dict(optimizer_state)
        self.lr_scheduler.load_state_dict(lr_scheduler_state)
    else:
        self.optimizer.load_state_dict(torch.load(optimizer_file, map_location=self.args.device))
        with warnings.catch_warnings(record=True) as caught_warnings:
            self.lr_scheduler.load_state_dict(torch.load(scheduler_file))
        reissue_pt_warnings(caught_warnings)
|
|
|
|
def hyperparameter_search(
    self,
    hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
    compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
    n_trials: int = 20,
    direction: str = "minimize",
    backend: Optional[Union["str", HPSearchBackend]] = None,
    hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
    **kwargs
) -> BestRun:
    """
    Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
    :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
    the sum of all metrics otherwise.

    .. warning::

        To use this method, you need to have provided a ``model_init`` when initializing your
        :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
        with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
        method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.

    Args:
        hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
            A function that defines the hyperparameter search space. Will default to
            :func:`~transformers.trainer_utils.default_hp_space_optuna` or
            :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
        compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
            A function computing the objective to minimize or maximize from the metrics returned by the
            :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
        n_trials (:obj:`int`, `optional`, defaults to 20):
            The number of trial runs to test.
        direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
            Whether to minimize or maximize the objective. Can be :obj:`"minimize"` or :obj:`"maximize"`, you
            should pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing
            one or several metrics.
        backend (:obj:`str` or :class:`~transformers.trainer_utils.HPSearchBackend`, `optional`):
            The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
            one is installed. If both are installed, will default to optuna.
        hp_name (:obj:`Callable[["optuna.Trial"], str]`, `optional`):
            A function returning the name of a trial. It is used as the run's checkpoint sub-folder; when
            :obj:`None`, runs are named ``run-{trial id}``.
        kwargs:
            Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
            more information see:

            - the documentation of `optuna.create_study
              <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
            - the documentation of `tune.run
              <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__

    Returns:
        :class:`transformers.trainer_utils.BestRun`: All the information about the best run.

    Raises:
        RuntimeError: if no backend is installed, if the requested backend is not installed, or if the trainer
        was created without a ``model_init``.
    """
    if backend is None:
        backend = default_hp_search_backend()
        if backend is None:
            raise RuntimeError(
                "At least one of optuna or ray should be installed. "
                "To install optuna run `pip install optuna`. "
                "To install ray run `pip install ray[tune]`."
            )
    backend = HPSearchBackend(backend)
    if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
        raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
    if backend == HPSearchBackend.RAY and not is_ray_available():
        raise RuntimeError(
            "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
        )
    if self.model_init is None:
        raise RuntimeError(
            "To use hyperparameter search, you need to pass your model through a model_init function."
        )

    self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
    self.hp_name = hp_name
    self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

    run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray

    # The backend is only set for the duration of the search (_save_checkpoint, for instance, checks it to decide
    # where checkpoints go). Reset it even if a trial raises, so later plain training runs are unaffected.
    self.hp_search_backend = backend
    try:
        best_run = run_hp_search(self, n_trials, direction, **kwargs)
    finally:
        self.hp_search_backend = None
    return best_run
|
|
|
|
def log(self, logs: Dict[str, float]) -> None:
    """
    Send :obj:`logs` to the callbacks and append them, tagged with the global step, to the log history.

    Subclass and override this method to inject custom behavior.

    Args:
        logs (:obj:`Dict[str, float]`):
            The values to log. When the trainer state knows the current epoch, an ``"epoch"`` entry is added to
            this dictionary in place before it is dispatched.
    """
    current_epoch = self.state.epoch
    if current_epoch is not None:
        logs["epoch"] = current_epoch

    self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    history_entry = dict(logs)
    history_entry["step"] = self.state.global_step
    self.state.log_history.append(history_entry)
|
|
|
|
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
|
"""
|
|
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
|
handling potential state.
|
|
"""
|
|
for k, v in inputs.items():
|
|
if isinstance(v, torch.Tensor):
|
|
inputs[k] = v.to(self.args.device)
|
|
|
|
if self.args.past_index >= 0 and self._past is not None:
|
|
inputs["mems"] = self._past
|
|
|
|
return inputs
|
|
|
|
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    """
    Run forward and backward passes on one batch; the optimizer step is left to the caller.

    Subclass and override to inject custom behavior.

    Args:
        model (:obj:`nn.Module`):
            The model to train. It is switched to training mode.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument :obj:`labels`. Check your model's documentation for all accepted arguments.

    Return:
        :obj:`torch.Tensor`: The detached training loss for this batch, already divided by
        ``args.gradient_accumulation_steps`` when that is greater than 1.
    """
    model.train()
    batch = self._prepare_inputs(inputs)

    native_amp = self.args.fp16 and _use_native_amp
    if native_amp:
        with autocast():
            loss = self.compute_loss(model, batch)
    else:
        loss = self.compute_loss(model, batch)

    if self.args.n_gpu > 1:
        # Average the per-device losses produced by multi-GPU parallel training.
        loss = loss.mean()

    accumulation_steps = self.args.gradient_accumulation_steps
    if accumulation_steps > 1:
        loss = loss / accumulation_steps

    if native_amp:
        self.scaler.scale(loss).backward()
    elif self.args.fp16 and _use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    return loss.detach()
|
|
|
|
def compute_loss(self, model, inputs):
    """
    Compute the training loss: the ``"loss"`` entry of a dict-like model output, otherwise its first element.

    When ``args.past_index >= 0``, the output at that index is stored in ``self._past`` for the next step.

    Subclass and override for custom behavior.
    """
    outputs = model(**inputs)

    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    past_index = self.args.past_index
    if past_index >= 0:
        self._past = outputs[past_index]

    # Not using `.loss`: the model may return a plain tuple instead of a ModelOutput.
    if isinstance(outputs, dict):
        return outputs["loss"]
    return outputs[0]
|
|
|
|
def is_local_process_zero(self) -> bool:
    """
    Whether this process is the main process of its machine (e.g. when training in a distributed fashion on
    several machines, there is one such process per machine).

    On TPU this defers to ``xm.is_master_ordinal(local=True)``; otherwise it checks ``args.local_rank``.
    """
    if not is_torch_tpu_available():
        return self.args.local_rank in (-1, 0)
    return xm.is_master_ordinal(local=True)
|
|
|
|
def is_world_process_zero(self) -> bool:
    """
    Whether this process is the single global main process (when training in a distributed fashion on several
    machines, this is only going to be :obj:`True` for one process).

    Non-distributed runs (``args.local_rank == -1``) always count as the main process.
    """
    if is_torch_tpu_available():
        return xm.is_master_ordinal(local=False)
    if self.args.local_rank == -1:
        return True
    return torch.distributed.get_rank() == 0
|
|
|
|
def save_model(self, output_dir: Optional[str] = None):
    """
    Save the model so it can be reloaded with :obj:`from_pretrained()`.

    On TPU every process calls :meth:`_save_tpu`. Elsewhere, only the world main process calls :meth:`_save`.
    """
    if is_torch_tpu_available():
        self._save_tpu(output_dir)
        return
    if self.is_world_process_zero():
        self._save(output_dir)
|
|
|
|
def _save_tpu(self, output_dir: Optional[str] = None):
    """
    TPU variant of :meth:`_save`; defaults to ``args.output_dir`` when ``output_dir`` is :obj:`None`.

    The master ordinal creates the folder and writes ``training_args.bin``. All processes then meet at the
    ``"saving_checkpoint"`` rendezvous before the weights (and the tokenizer, from the world main process) are
    saved.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    logger.info("Saving model checkpoint to %s", output_dir)

    if xm.is_master_ordinal():
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    xm.rendezvous("saving_checkpoint")
    if isinstance(self.model, PreTrainedModel):
        self.model.save_pretrained(output_dir)
    else:
        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
        xm.save(self.model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))

    if self.tokenizer is not None and self.is_world_process_zero():
        self.tokenizer.save_pretrained(output_dir)
|
|
|
|
def _save(self, output_dir: Optional[str] = None):
    """
    Write the model, the tokenizer (if any) and the training arguments to ``output_dir``.

    Defaults to ``args.output_dir`` when ``output_dir`` is :obj:`None`; the folder is created if needed.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Saving model checkpoint to %s", output_dir)

    # Models using `save_pretrained()` (weights + configuration) can be reloaded with `from_pretrained()`;
    # anything else only gets its state dict written.
    if isinstance(self.model, PreTrainedModel):
        self.model.save_pretrained(output_dir)
    else:
        logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
        torch.save(self.model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))

    if self.tokenizer is not None and self.is_world_process_zero():
        self.tokenizer.save_pretrained(output_dir)

    # Keep the training arguments next to the trained model.
    torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
|
|
|
def store_flos(self):
    """
    Copy this process's floating-point-operation count into ``self.state.total_flos``.

    In distributed mode (``args.local_rank != -1``) the counts go through ``distributed_broadcast_scalars`` and
    are summed. Nothing happens when ``self._total_flos`` is :obj:`None`.
    """
    local_flos = self._total_flos
    if local_flos is None:
        return
    if self.args.local_rank == -1:
        self.state.total_flos = local_flos
    else:
        self.state.total_flos = distributed_broadcast_scalars([local_flos]).sum().item()
|
|
|
|
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
    """
    List the checkpoint folders of ``args.output_dir``, oldest first.

    Args:
        checkpoint_prefix (:obj:`str`):
            Folder-name prefix; folders matching ``{checkpoint_prefix}-*`` are considered.
        use_mtime (:obj:`bool`):
            Order by modification time instead of by the step number parsed from the folder name. Without it,
            folders whose name has no trailing step number are ignored.

    Returns:
        :obj:`List[str]`: The checkpoint paths, oldest first. The exception is the best checkpoint
        (``state.best_model_checkpoint``), which is moved to the very end when it is among them, so callers
        deleting from the front never remove it. All other checkpoints keep their chronological order.
    """
    glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

    # Escape the prefix so it is matched literally, not as a regular expression.
    step_pattern = re.compile(f".*{re.escape(checkpoint_prefix)}-([0-9]+)")
    ordering_and_checkpoint_path = []
    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = step_pattern.match(path)
            if regex_match is not None:
                ordering_and_checkpoint_path.append((int(regex_match.group(1)), path))

    checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

    # Make sure we don't delete the best model. Move it to the end rather than swapping it with the last entry:
    # a swap would push the newest checkpoint towards the front, where rotation deletes it. The best checkpoint
    # may also live elsewhere (e.g. in a per-run folder during hyperparameter search); then there is nothing to
    # protect here, so don't fail with ValueError.
    if self.state.best_model_checkpoint is not None:
        best_path = str(Path(self.state.best_model_checkpoint))
        if best_path in checkpoints_sorted:
            checkpoints_sorted.remove(best_path)
            checkpoints_sorted.append(best_path)
    return checkpoints_sorted
|
|
|
|
def _rotate_checkpoints(self, use_mtime=False) -> None:
|
|
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
|
|
return
|
|
|
|
# Check if we should delete older checkpoint(s)
|
|
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
|
|
if len(checkpoints_sorted) <= self.args.save_total_limit:
|
|
return
|
|
|
|
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
|
|
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
|
for checkpoint in checkpoints_to_be_deleted:
|
|
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
|
shutil.rmtree(checkpoint)
|
|
|
|
def evaluate(
    self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
) -> Dict[str, float]:
    """
    Run evaluation and returns metrics.

    The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
    (pass it to the init :obj:`compute_metrics` argument).

    You can also subclass and override this method to inject custom behavior.

    Args:
        eval_dataset (:obj:`Dataset`, `optional`):
            Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
            columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
            :obj:`__len__` method.
        ignore_keys (:obj:`List[str]`, `optional`):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions.

    Returns:
        A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
        dictionary also contains the epoch number which comes from the training state.

    Raises:
        ValueError: if ``eval_dataset`` is given but does not implement ``__len__``.
    """
    if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
        raise ValueError("eval_dataset must implement __len__")

    eval_dataloader = self.get_eval_dataloader(eval_dataset)

    output = self.prediction_loop(
        eval_dataloader,
        description="Evaluation",
        # No point gathering the predictions if there are no metrics, otherwise we defer to
        # self.args.prediction_loss_only
        prediction_loss_only=True if self.compute_metrics is None else None,
        ignore_keys=ignore_keys,
    )

    self.log(output.metrics)

    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
    # `xm`/`met` are only usable when a TPU is available, so guard on it like the training loop does instead of
    # raising NameError when the debug flags are set on a GPU/CPU machine.
    if (self.args.tpu_metrics_debug or self.args.debug) and is_torch_tpu_available():
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
    return output.metrics
|
|
|
|
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
    """
    Run prediction and returns predictions and potential metrics.

    Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
    will also return metrics, like in :obj:`evaluate()`.

    Args:
        test_dataset (:obj:`Dataset`):
            Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
            ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`.
        ignore_keys (:obj:`List[str]`, `optional`):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions.

    .. note::

        If your predictions or labels have different sequence length (for instance because you're doing dynamic
        padding in a token classification task) the predictions will be padded (on the right) to allow for
        concatenation into one array. The padding index is -100.

    Returns: `NamedTuple` A namedtuple with the following keys:

        - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
        - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
        - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
          contained labels).

    Raises:
        ValueError: if ``test_dataset`` does not implement ``__len__``.
    """
    has_length = test_dataset is None or isinstance(test_dataset, collections.abc.Sized)
    if not has_length:
        raise ValueError("test_dataset must implement __len__")

    return self.prediction_loop(
        self.get_test_dataloader(test_dataset), description="Prediction", ignore_keys=ignore_keys
    )
|
|
|
|
def prediction_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[List[str]] = None,
) -> PredictionOutput:
    """
    Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

    Works both with or without labels. Per-batch losses, logits and labels are accumulated on device. Every
    ``args.eval_accumulation_steps`` steps, and once more at the end, they are gathered across processes,
    converted to numpy and handed to :class:`DistributedTensorGatherer` objects.

    Args:
        dataloader (:obj:`DataLoader`):
            Loader whose dataset must implement ``__len__``.
        description (:obj:`str`):
            Label used in the log messages (e.g. ``"Evaluation"``).
        prediction_loss_only (:obj:`bool`, `optional`):
            Only collect the loss. Falls back to ``args.prediction_loss_only`` when :obj:`None`.
        ignore_keys (:obj:`List[str]`, `optional`):
            Forwarded to :meth:`prediction_step`.

    Returns:
        :class:`PredictionOutput`: predictions and label ids (both :obj:`None` when only the loss is collected)
        plus the metrics, with every metric name prefixed by ``eval_``.
    """
    if not isinstance(dataloader.dataset, collections.abc.Sized):
        raise ValueError("dataset must implement __len__")
    if prediction_loss_only is None:
        prediction_loss_only = self.args.prediction_loss_only

    model = self.model
    # multi-gpu eval (model-parallel models are not wrapped in DataParallel)
    if self.args.n_gpu > 1 and not self.args.model_parallel:
        model = torch.nn.DataParallel(model)
    # Note: in torch.distributed mode, there's no point in wrapping the model
    # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

    batch_size = dataloader.batch_size
    num_examples = self.num_examples(dataloader)
    logger.info("***** Running %s *****", description)
    logger.info(" Num examples = %d", num_examples)
    logger.info(" Batch size = %d", batch_size)

    if is_torch_tpu_available():
        world_size = xm.xrt_world_size()
    elif self.args.local_rank != -1:
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    world_size = max(1, world_size)

    losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
    if prediction_loss_only:
        preds_gatherer = labels_gatherer = None
    else:
        preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
        labels_gatherer = DistributedTensorGatherer(world_size, num_examples)

    def flush(losses, preds, labels):
        # Gather the accumulated tensors, put them back on the CPU as numpy arrays and store them.
        losses_gatherer.add_arrays(self._gather_and_numpify(losses, "eval_losses"))
        if preds_gatherer is not None:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels, "eval_label_ids"))

    model.eval()

    if is_torch_tpu_available():
        dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

    if self.args.past_index >= 0:
        self._past = None

    self.callback_handler.eval_dataloader = dataloader

    losses_host: torch.Tensor = None
    preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
    labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

    for step, inputs in enumerate(dataloader):
        loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
        if loss is not None:
            # prediction_step yields one (mean) loss per batch; repeat it once per example of the batch.
            repeated = loss.repeat(batch_size)
            losses_host = repeated if losses_host is None else torch.cat((losses_host, repeated), dim=0)
        if logits is not None:
            preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
        if labels is not None:
            labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
        self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

        # Offload to the CPU once enough accumulation steps have been done.
        if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
            flush(losses_host, preds_host, labels_host)
            # Start a new accumulation.
            losses_host = preds_host = labels_host = None

    if self.args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of the evaluation loop.
        # NOTE(review): `past_index == 0` is falsy, so the state is kept in that case.
        delattr(self, "_past")

    # Whatever is still accumulated on device.
    flush(losses_host, preds_host, labels_host)

    eval_loss = losses_gatherer.finalize()
    preds = None if preds_gatherer is None else preds_gatherer.finalize()
    label_ids = None if labels_gatherer is None else labels_gatherer.finalize()

    can_compute_metrics = self.compute_metrics is not None and preds is not None and label_ids is not None
    if can_compute_metrics:
        metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
    else:
        metrics = {}

    if eval_loss is not None:
        metrics["eval_loss"] = eval_loss.mean().item()

    # Prefix all keys with eval_
    for key in [name for name in metrics if not name.startswith("eval_")]:
        metrics[f"eval_{key}"] = metrics.pop(key)

    return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
|
|
|
def _gather_and_numpify(self, tensors, name):
|
|
"""
|
|
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
|
|
concatenating them to `gathered`
|
|
"""
|
|
if tensors is None:
|
|
return
|
|
if is_torch_tpu_available():
|
|
tensors = nested_xla_mesh_reduce(tensors, name)
|
|
elif self.args.local_rank != -1:
|
|
tensors = distributed_concat(tensors)
|
|
|
|
return nested_numpify(tensors)
|
|
|
|
def prediction_step(
    self,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Perform an evaluation step on :obj:`model` using :obj:`inputs`.

    Subclass and override to inject custom behavior.

    Args:
        model (:obj:`nn.Module`):
            The model to evaluate.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument :obj:`labels`. Check your model's documentation for all accepted arguments.
        prediction_loss_only (:obj:`bool`):
            Whether or not to return the loss only.
        ignore_keys (:obj:`List[str]`, `optional`):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions. When not given, ``self.model.config.keys_to_ignore_at_inference`` is used if
            available, otherwise an empty list.

    Return:
        Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits
        and labels (each being optional). The loss is a detached, averaged tensor, or :obj:`None` when the inputs
        do not carry every label in ``self.label_names`` (or no label names are configured).
    """
    # ``all()`` over an empty iterable is True, so guard explicitly: with no label names configured, batches must
    # not be treated as labeled (that would misread the first model output as a loss).
    has_labels = bool(self.label_names) and all(inputs.get(k) is not None for k in self.label_names)
    inputs = self._prepare_inputs(inputs)
    if ignore_keys is None:
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    with torch.no_grad():
        if self.args.fp16 and _use_native_amp:
            with autocast():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)
        if has_labels:
            if isinstance(outputs, dict):
                loss = outputs["loss"].mean().detach()
                skipped_keys = ignore_keys + ["loss"]
                logits = tuple(v for k, v in outputs.items() if k not in skipped_keys)
            else:
                # With labels, tuple outputs are assumed to carry the loss in position 0.
                loss = outputs[0].mean().detach()
                logits = outputs[1:]
        else:
            loss = None
            if isinstance(outputs, dict):
                logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
            else:
                logits = outputs
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            # Without a loss at position 0, the outputs are shifted down by one slot.
            self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

    if prediction_loss_only:
        return (loss, None, None)

    logits = nested_detach(logits)
    if len(logits) == 1:
        logits = logits[0]

    if has_labels:
        labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
        if len(labels) == 1:
            labels = labels[0]
    else:
        labels = None

    return (loss, logits, labels)
|
|
|
|
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
    """
    Return the number of floating-point operations the model reports for one backward + forward pass on
    :obj:`inputs`.

    The model is first unwrapped with ``self._actual_model`` and then asked through its own ``floating_point_ops``
    method (provided by :class:`~transformers.PreTrainedModel`). Models without such a method count as 0; either
    implement it on the model or subclass and override this method.

    Args:
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

    Returns:
        :obj:`int`: The number of floating-point operations.
    """
    unwrapped = self._actual_model(self.model)
    if not hasattr(unwrapped, "floating_point_ops"):
        # No FLOP counter available on this model.
        return 0
    return unwrapped.floating_point_ops(inputs)
|
|
|
|
@staticmethod
|
|
def _actual_model(
|
|
model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]
|
|
) -> torch.nn.modules.Module:
|
|
"""
|
|
|
|
Args:
|
|
model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`):
|
|
Model object used during training
|
|
|
|
Returns:
|
|
:obj:`torch.nn.modules.Module`: unwrapped module
|
|
"""
|
|
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
|
model = model.module
|
|
else:
|
|
model = model
|
|
return model
|