Trainer callbacks (#7596)
* Initial callback proposal * Finish various callbacks * Post-rebase conflicts * Fix tests * Don't use something that's not set * Documentation * Remove unwanted print. * Document all models can work * Add tests + small fixes * Update docs/source/internal/trainer_utils.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Fix TF tests * Real fix this time * This one should work * Fix typo * Really fix typo Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -213,6 +213,7 @@ conversion utilities for the following models:
|
||||
:maxdepth: 2
|
||||
:caption: Main Classes
|
||||
|
||||
main_classes/callback
|
||||
main_classes/configuration
|
||||
main_classes/logging
|
||||
main_classes/model
|
||||
@@ -270,3 +271,4 @@ conversion utilities for the following models:
|
||||
internal/modeling_utils
|
||||
internal/pipelines_utils
|
||||
internal/tokenization_utils
|
||||
internal/trainer_utils
|
||||
|
||||
21
docs/source/internal/trainer_utils.rst
Normal file
21
docs/source/internal/trainer_utils.rst
Normal file
@@ -0,0 +1,21 @@
|
||||
Utilities for Trainer
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
This page lists all the utility functions used by :class:`~transformers.Trainer`.
|
||||
|
||||
Most of those are only useful if you are studying the code of the Trainer in the library.
|
||||
|
||||
Utilities
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.EvalPrediction
|
||||
|
||||
.. autofunction:: transformers.set_seed
|
||||
|
||||
.. autofunction:: transformers.torch_distributed_zero_first
|
||||
|
||||
|
||||
Callbacks internals
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.trainer_callback.CallbackHandler
|
||||
68
docs/source/main_classes/callback.rst
Normal file
68
docs/source/main_classes/callback.rst
Normal file
@@ -0,0 +1,68 @@
|
||||
Callbacks
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Callbacks are objects that can customize the behavior of the training loop in the PyTorch
|
||||
:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow) that can inspect the training loop
|
||||
state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early
|
||||
stopping).
|
||||
|
||||
Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
|
||||
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
|
||||
subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).
|
||||
|
||||
By default a :class:`~transformers.Trainer` will use the following callbacks:
|
||||
|
||||
- :class:`~transformers.DefaultFlowCallback` which handles the default beahvior for logging, saving and evaluation.
|
||||
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
|
||||
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
|
||||
it's the second one).
|
||||
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
|
||||
or tensorboardX).
|
||||
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
|
||||
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
|
||||
|
||||
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
|
||||
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
|
||||
Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
|
||||
:class:`~transformers.TrainerControl`.
|
||||
|
||||
|
||||
Available Callbacks
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here is the list of the available :class:`~transformers.TrainerCallback` in the library:
|
||||
|
||||
.. autoclass:: transformers.integrations.CometCallback
|
||||
:members: setup
|
||||
|
||||
.. autoclass:: transformers.DefaultFlowCallback
|
||||
|
||||
.. autoclass:: transformers.PrinterCallback
|
||||
|
||||
.. autoclass:: transformers.ProgressCallback
|
||||
|
||||
.. autoclass:: transformers.integrations.TensorBoardCallback
|
||||
|
||||
.. autoclass:: transformers.integrations.WandbCallback
|
||||
:members: setup
|
||||
|
||||
|
||||
TrainerCallback
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainerCallback
|
||||
:members:
|
||||
|
||||
|
||||
TrainerState
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainerState
|
||||
:members:
|
||||
|
||||
|
||||
TrainerControl
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainerControl
|
||||
:members:
|
||||
@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override
|
||||
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset.
|
||||
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
|
||||
- **log** -- Logs information on the various objects watching training.
|
||||
- **setup_wandb** -- Setups wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
|
||||
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
|
||||
init.
|
||||
- **compute_loss** - Computes the loss on a batch of training inputs.
|
||||
@@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu
|
||||
logits = outputs[0]
|
||||
return my_custom_loss(logits, labels)
|
||||
|
||||
Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
|
||||
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
|
||||
other ML platforms...) and take decisions (like early stopping).
|
||||
|
||||
|
||||
Trainer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -47,29 +50,23 @@ Trainer
|
||||
.. autoclass:: transformers.Trainer
|
||||
:members:
|
||||
|
||||
|
||||
TFTrainer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFTrainer
|
||||
:members:
|
||||
|
||||
|
||||
TrainingArguments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrainingArguments
|
||||
:members:
|
||||
|
||||
|
||||
TFTrainingArguments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFTrainingArguments
|
||||
:members:
|
||||
|
||||
Utilities
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.EvalPrediction
|
||||
|
||||
.. autofunction:: transformers.set_seed
|
||||
|
||||
.. autofunction:: transformers.torch_distributed_zero_first
|
||||
|
||||
@@ -9,7 +9,7 @@ from transformers import Trainer
|
||||
from transformers.configuration_fsmt import FSMTConfig
|
||||
from transformers.file_utils import is_torch_tpu_available
|
||||
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
||||
from transformers.trainer import get_tpu_sampler
|
||||
from transformers.trainer_pt_utils import get_tpu_sampler
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -4,7 +4,8 @@ import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.trainer_utils import TrainerState, set_seed
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import main
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
|
||||
@@ -205,7 +205,15 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||
|
||||
# Trainer
|
||||
from .trainer_utils import EvalPrediction, TrainerState, set_seed
|
||||
from .trainer_callback import (
|
||||
DefaultFlowCallback,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
)
|
||||
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
|
||||
from .training_args import TrainingArguments
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
from .utils import logging
|
||||
@@ -529,7 +537,8 @@ if is_torch_available():
|
||||
from .tokenization_marian import MarianTokenizer
|
||||
|
||||
# Trainer
|
||||
from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
|
||||
from .trainer import Trainer
|
||||
from .trainer_pt_utils import torch_distributed_zero_first
|
||||
else:
|
||||
from .utils.dummy_pt_objects import *
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
from .file_utils import is_torch_tpu_available
|
||||
from .trainer_callback import TrainerCallback
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
|
||||
from .utils import logging
|
||||
|
||||
|
||||
try:
|
||||
import comet_ml # noqa: F401
|
||||
@@ -36,15 +41,6 @@ try:
|
||||
except (ImportError):
|
||||
_has_ray = False
|
||||
|
||||
|
||||
# No ML framework or transformer imports above this point
|
||||
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # isort:skip
|
||||
from .utils import logging # isort:skip
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
||||
|
||||
@@ -57,9 +53,10 @@ except ImportError:
|
||||
except ImportError:
|
||||
_has_tensorboard = False
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Integration functions:
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return _has_wandb
|
||||
|
||||
@@ -128,8 +125,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
|
||||
# The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
|
||||
# while doing the ray hp search.
|
||||
_tb_writer = trainer.tb_writer
|
||||
trainer.tb_writer = None
|
||||
|
||||
_tb_writer = trainer.pop_callback(TensorBoardCallback)
|
||||
trainer.model = None
|
||||
# Setup default `resources_per_trial` and `reporter`.
|
||||
if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
|
||||
@@ -182,5 +179,159 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
|
||||
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
|
||||
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
|
||||
trainer.tb_writer = _tb_writer
|
||||
if _tb_writer is not None:
|
||||
trainer.add_callback(_tb_writer)
|
||||
return best_run
|
||||
|
||||
|
||||
class TensorBoardCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
|
||||
<https://www.tensorflow.org/tensorboard>`__.
|
||||
|
||||
Args:
|
||||
tb_writer (:obj:`SummaryWriter`, `optional`):
|
||||
The writer to use. Will instatiate one if not set.
|
||||
"""
|
||||
|
||||
def __init__(self, tb_writer=None):
|
||||
assert (
|
||||
_has_tensorboard
|
||||
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
|
||||
self.tb_writer = tb_writer
|
||||
|
||||
def on_init_end(self, args, state, control, **kwargs):
|
||||
if self.tb_writer is None and state.is_world_process_zero:
|
||||
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_text("args", args.to_json_string())
|
||||
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if self.tb_writer:
|
||||
for k, v in logs.items():
|
||||
if isinstance(v, (int, float)):
|
||||
self.tb_writer.add_scalar(k, v, state.global_step)
|
||||
else:
|
||||
logger.warning(
|
||||
"Trainer is attempting to log a value of "
|
||||
'"%s" of type %s for key "%s" as a scalar. '
|
||||
"This invocation of Tensorboard's writer.add_scalar() "
|
||||
"is incorrect so we dropped this attribute.",
|
||||
v,
|
||||
type(v),
|
||||
k,
|
||||
)
|
||||
self.tb_writer.flush()
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
if self.tb_writer:
|
||||
self.tb_writer.close()
|
||||
|
||||
|
||||
class WandbCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that sends the logs to `Weight and Biases
|
||||
<https://www.wandb.com/>`__.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
|
||||
self._initialized = False
|
||||
|
||||
def setup(self, args, state, model):
|
||||
"""
|
||||
Setup the optional Weights & Biases (`wandb`) integration.
|
||||
|
||||
One can subclass and override this method to customize the setup if needed. Find more information
|
||||
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
|
||||
|
||||
Environment:
|
||||
WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
|
||||
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
|
||||
logging or :obj:`"all"` to log gradients and parameters.
|
||||
WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
|
||||
Set this to a custom string to store results in a different project.
|
||||
WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to disable wandb entirely.
|
||||
"""
|
||||
self._initialized = True
|
||||
if state.is_world_process_zero:
|
||||
logger.info(
|
||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||
)
|
||||
combined_dict = {**args.to_sanitized_dict()}
|
||||
if hasattr(model, "config"):
|
||||
combined_dict = {**model.config.to_dict(), **combined_dict}
|
||||
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
|
||||
# keep track of model topology and gradients, unsupported on TPU
|
||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model)
|
||||
|
||||
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model)
|
||||
if state.is_world_process_zero:
|
||||
wandb.log(logs, step=state.global_step)
|
||||
|
||||
|
||||
class CometCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML
|
||||
<https://www.comet.ml/site/>`__.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
|
||||
self._initialized = False
|
||||
|
||||
def setup(self, args, state, model):
|
||||
"""
|
||||
Setup the optional Comet.ml integration.
|
||||
|
||||
Environment:
|
||||
COMET_MODE (:obj:`str`, `optional`):
|
||||
"OFFLINE", "ONLINE", or "DISABLED"
|
||||
COMET_PROJECT_NAME (:obj:`str`, `optional`):
|
||||
Comet.ml project name for experiments
|
||||
COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
|
||||
Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"
|
||||
|
||||
For a number of configurable items in the environment,
|
||||
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
|
||||
"""
|
||||
self._initialized = True
|
||||
if state.is_world_process_zero:
|
||||
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
|
||||
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
|
||||
experiment = None
|
||||
if comet_mode == "ONLINE":
|
||||
experiment = comet_ml.Experiment(**args)
|
||||
logger.info("Automatic Comet.ml online logging enabled")
|
||||
elif comet_mode == "OFFLINE":
|
||||
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
|
||||
experiment = comet_ml.OfflineExperiment(**args)
|
||||
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
|
||||
if experiment is not None:
|
||||
experiment._set_model_graph(model, framework="transformers")
|
||||
experiment._log_parameters(args, prefix="args/", framework="transformers")
|
||||
if hasattr(model, "config"):
|
||||
experiment._log_parameters(model.config, prefix="config/", framework="transformers")
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model)
|
||||
|
||||
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model)
|
||||
if state.is_world_process_zero:
|
||||
experiment = comet_ml.config.get_global_experiment()
|
||||
if experiment is not None:
|
||||
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
|
||||
|
||||
@@ -1,10 +1,26 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -15,8 +31,7 @@ from torch import nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
||||
from tqdm.auto import tqdm, trange
|
||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
|
||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
|
||||
@@ -34,23 +49,35 @@ from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from .trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
BestRun,
|
||||
EvalPrediction,
|
||||
EvaluationStrategy,
|
||||
HPSearchBackend,
|
||||
PredictionOutput,
|
||||
from .trainer_callback import (
|
||||
CallbackHandler,
|
||||
DefaultFlowCallback,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
)
|
||||
from .trainer_pt_utils import (
|
||||
SequentialDistributedSampler,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
get_tpu_sampler,
|
||||
nested_concat,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
nested_xla_mesh_reduce,
|
||||
reissue_pt_warnings,
|
||||
)
|
||||
from .trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
BestRun,
|
||||
EvalPrediction,
|
||||
HPSearchBackend,
|
||||
PredictionOutput,
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
set_seed,
|
||||
)
|
||||
from .training_args import TrainingArguments
|
||||
@@ -60,7 +87,8 @@ from .utils import logging
|
||||
_use_native_amp = False
|
||||
_use_apex = False
|
||||
|
||||
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||
|
||||
|
||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
||||
@@ -82,16 +110,20 @@ if is_torch_tpu_available():
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
|
||||
if is_tensorboard_available():
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
from .integrations import TensorBoardCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(TensorBoardCallback)
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
from .integrations import WandbCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(WandbCallback)
|
||||
|
||||
if is_comet_available():
|
||||
import comet_ml
|
||||
from .integrations import CometCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(CometCallback)
|
||||
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
@@ -102,91 +134,20 @@ if is_ray_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def reissue_pt_warnings(caught_warnings):
|
||||
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
||||
if len(caught_warnings) > 1:
|
||||
for w in caught_warnings:
|
||||
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
||||
warnings.warn(w.message, w.category)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
"""
|
||||
Decorator to make all processes in distributed training wait for each local_master to do something.
|
||||
|
||||
Args:
|
||||
local_rank (:obj:`int`): The rank of the local process.
|
||||
"""
|
||||
if local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier()
|
||||
yield
|
||||
if local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
class SequentialDistributedSampler(Sampler):
|
||||
"""
|
||||
Distributed Sampler that subsamples indicies sequentially,
|
||||
making it easier to collate all results at the end.
|
||||
|
||||
Even though we only use this sampler for eval and predict (no training),
|
||||
which means that the model params won't have to be synced (i.e. will not hang
|
||||
for synchronization even if varied number of forward passes), we still add extra
|
||||
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
|
||||
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||
if num_replicas is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
if rank is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = torch.distributed.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
|
||||
def __iter__(self):
|
||||
indices = list(range(len(self.dataset)))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[: (self.total_size - len(indices))]
|
||||
assert (
|
||||
len(indices) == self.total_size
|
||||
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
||||
assert (
|
||||
len(indices) == self.num_samples
|
||||
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
def get_tpu_sampler(dataset: Dataset):
|
||||
if xm.xrt_world_size() <= 1:
|
||||
return RandomSampler(dataset)
|
||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
||||
optimized for 🤗 Transformers.
|
||||
|
||||
Args:
|
||||
model (:class:`~transformers.PreTrainedModel`, `optional`):
|
||||
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
|
||||
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
|
||||
|
||||
.. note::
|
||||
|
||||
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
|
||||
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
|
||||
they work the same way as the 🤗 Transformers models.
|
||||
args (:class:`~transformers.TrainingArguments`, `optional`):
|
||||
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
|
||||
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
|
||||
@@ -210,8 +171,11 @@ class Trainer:
|
||||
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
||||
The function that will be used to compute metrics at evaluation. Must take a
|
||||
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
||||
tb_writer (:obj:`SummaryWriter`, `optional`):
|
||||
Object to write to TensorBoard.
|
||||
callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
|
||||
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
|
||||
detailed in :doc:`here <callback>`.
|
||||
|
||||
If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
|
||||
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
|
||||
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
|
||||
:class:`~transformers.AdamW` on your model and a scheduler given by
|
||||
@@ -222,7 +186,7 @@ class Trainer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModel = None,
|
||||
model: Union[PreTrainedModel, torch.nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
@@ -230,7 +194,7 @@ class Trainer:
|
||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
tb_writer: Optional["SummaryWriter"] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
**kwargs,
|
||||
):
|
||||
@@ -259,7 +223,21 @@ class Trainer:
|
||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
self.tb_writer = tb_writer
|
||||
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
||||
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
|
||||
|
||||
# Deprecated arguments
|
||||
if "tb_writer" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `tb_writer` as a keyword argument is deprecated and won't be possible in a "
|
||||
+ "future version. Use `TensorBoardCallback(tb_writer=...)` instead and pass it to the `callbacks`"
|
||||
+ "argument",
|
||||
FutureWarning,
|
||||
)
|
||||
tb_writer = kwargs.pop("tb_writer")
|
||||
self.remove_callback(TensorBoardCallback)
|
||||
self.add_callback(TensorBoardCallback(tb_writer=tb_writer))
|
||||
if "prediction_loss_only" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
|
||||
@@ -270,13 +248,6 @@ class Trainer:
|
||||
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
|
||||
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
||||
if not is_tensorboard_available():
|
||||
logger.warning(
|
||||
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
||||
)
|
||||
|
||||
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
||||
self._loggers_initialized = False
|
||||
|
||||
@@ -304,6 +275,7 @@ class Trainer:
|
||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
||||
# state at each call to self.log.
|
||||
self._total_flos = None
|
||||
@@ -317,6 +289,45 @@ class Trainer:
|
||||
else ["labels"]
|
||||
)
|
||||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||
|
||||
def add_callback(self, callback):
|
||||
"""
|
||||
Add a callback to the current list of :class:`~transformer.TrainerCallback`.
|
||||
|
||||
Args:
|
||||
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
||||
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
||||
In the first case, will instantiate a member of that class.
|
||||
"""
|
||||
self.callback_handler.add_callback(callback)
|
||||
|
||||
def pop_callback(self, callback):
|
||||
"""
|
||||
Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.
|
||||
|
||||
If the callback is not found, returns :obj:`None` (and no error is raised).
|
||||
|
||||
Args:
|
||||
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
||||
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
||||
In the first case, will pop the first member of that class found in the list of callbacks.
|
||||
|
||||
Returns:
|
||||
:class:`~transformer.TrainerCallback`: The callback removed, if found.
|
||||
"""
|
||||
return self.callback_handler.pop_callback(callback)
|
||||
|
||||
def remove_callback(self, callback):
|
||||
"""
|
||||
Remove a callback from the current list of :class:`~transformer.TrainerCallback`.
|
||||
|
||||
Args:
|
||||
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
||||
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
||||
In the first case, will remove the first member of that class found in the list of callbacks.
|
||||
"""
|
||||
self.callback_handler.remove_callback(callback)
|
||||
|
||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||
if not self.args.remove_unused_columns:
|
||||
@@ -465,102 +476,12 @@ class Trainer:
|
||||
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
||||
)
|
||||
|
||||
def setup_wandb(self):
|
||||
"""
|
||||
Setup the optional Weights & Biases (`wandb`) integration.
|
||||
|
||||
One can subclass and override this method to customize the setup if needed. Find more information
|
||||
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
|
||||
|
||||
Environment:
|
||||
WANDB_WATCH:
|
||||
(Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
|
||||
or "all" to log gradients and parameters
|
||||
WANDB_PROJECT:
|
||||
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
|
||||
WANDB_DISABLED:
|
||||
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
|
||||
"""
|
||||
if hasattr(self, "_setup_wandb"):
|
||||
warnings.warn(
|
||||
"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._setup_wandb()
|
||||
|
||||
if self.is_world_process_zero():
|
||||
logger.info(
|
||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||
)
|
||||
combined_dict = {**self.args.to_sanitized_dict()}
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
combined_dict = {**self.model.config.to_dict(), **combined_dict}
|
||||
wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
||||
)
|
||||
# keep track of model topology and gradients, unsupported on TPU
|
||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||
wandb.watch(
|
||||
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
|
||||
)
|
||||
|
||||
def setup_comet(self):
|
||||
"""
|
||||
Setup the optional Comet.ml integration.
|
||||
|
||||
Environment:
|
||||
COMET_MODE:
|
||||
(Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
|
||||
COMET_PROJECT_NAME:
|
||||
(Optional): str - Comet.ml project name for experiments
|
||||
COMET_OFFLINE_DIRECTORY:
|
||||
(Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
|
||||
|
||||
For a number of configurable items in the environment,
|
||||
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
|
||||
"""
|
||||
if self.is_world_master():
|
||||
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
|
||||
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
|
||||
experiment = None
|
||||
if comet_mode == "ONLINE":
|
||||
experiment = comet_ml.Experiment(**args)
|
||||
logger.info("Automatic Comet.ml online logging enabled")
|
||||
elif comet_mode == "OFFLINE":
|
||||
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
|
||||
experiment = comet_ml.OfflineExperiment(**args)
|
||||
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
|
||||
if experiment is not None:
|
||||
experiment._set_model_graph(self.model, framework="transformers")
|
||||
experiment._log_parameters(self.args, prefix="args/", framework="transformers")
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
|
||||
|
||||
def num_examples(self, dataloader: DataLoader) -> int:
|
||||
"""
|
||||
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
||||
"""
|
||||
return len(dataloader.dataset)
|
||||
|
||||
def _setup_loggers(self):
|
||||
if self._loggers_initialized:
|
||||
return
|
||||
if is_wandb_available():
|
||||
self.setup_wandb()
|
||||
elif os.environ.get("WANDB_DISABLED") != "true":
|
||||
logger.info(
|
||||
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
|
||||
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
|
||||
)
|
||||
if is_comet_available():
|
||||
self.setup_comet()
|
||||
elif os.environ.get("COMET_MODE") != "DISABLED":
|
||||
logger.info(
|
||||
"To use comet_ml logging, run `pip/conda install comet_ml` "
|
||||
"see https://www.comet.ml/docs/python-sdk/huggingface/"
|
||||
)
|
||||
self._loggers_initialized = True
|
||||
|
||||
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||
""" HP search setup code """
|
||||
if self.hp_search_backend is None or trial is None:
|
||||
@@ -661,7 +582,7 @@ class Trainer:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
# Moxed precision training with apex (torch < 1.6)
|
||||
# Mixed precision training with apex (torch < 1.6)
|
||||
model = self.model
|
||||
if self.args.fp16 and _use_apex:
|
||||
if not is_apex_available():
|
||||
@@ -687,10 +608,6 @@ class Trainer:
|
||||
# find_unused_parameters breaks checkpointing as per
|
||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_text("args", self.args.to_json_string())
|
||||
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
|
||||
|
||||
# Train!
|
||||
if is_torch_tpu_available():
|
||||
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
||||
@@ -723,17 +640,25 @@ class Trainer:
|
||||
logger.info(" Continuing training from global step %d", self.state.global_step)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
|
||||
# Update the references
|
||||
self.callback_handler.model = self.model
|
||||
self.callback_handler.optimizer = self.optimizer
|
||||
self.callback_handler.lr_scheduler = self.lr_scheduler
|
||||
self.callback_handler.train_dataloader = train_dataloader
|
||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||
# to set this after the load.
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||
|
||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||
self._logging_loss_scalar = 0
|
||||
self._total_flos = self.state.total_flos
|
||||
logging_loss_scalar = 0.0
|
||||
model.zero_grad()
|
||||
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
||||
train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
|
||||
|
||||
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
|
||||
|
||||
for epoch in range(epochs_trained, num_train_epochs):
|
||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||
train_dataloader.sampler.set_epoch(epoch)
|
||||
@@ -750,15 +675,18 @@ class Trainer:
|
||||
if self.args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
|
||||
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
|
||||
|
||||
for step, inputs in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
epoch_pbar.update(1)
|
||||
continue
|
||||
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
||||
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
||||
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
self._total_flos += self.floating_point_ops(inputs)
|
||||
|
||||
@@ -787,50 +715,15 @@ class Trainer:
|
||||
model.zero_grad()
|
||||
self.state.global_step += 1
|
||||
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
|
||||
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
|
||||
self.state.global_step == 1 and self.args.logging_first_step
|
||||
):
|
||||
logs: Dict[str, float] = {}
|
||||
tr_loss_scalar = tr_loss.item()
|
||||
logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
|
||||
# backward compatibility for pytorch schedulers
|
||||
logs["learning_rate"] = (
|
||||
self.lr_scheduler.get_last_lr()[0]
|
||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||
else self.lr_scheduler.get_lr()[0]
|
||||
)
|
||||
logging_loss_scalar = tr_loss_scalar
|
||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||
|
||||
self.log(logs)
|
||||
|
||||
if (
|
||||
self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||
and self.state.global_step % self.args.eval_steps == 0
|
||||
):
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
if self.args.load_best_model_at_end:
|
||||
self._save_training(model, trial, metrics=metrics)
|
||||
|
||||
if (
|
||||
not self.args.load_best_model_at_end
|
||||
and self.args.save_steps > 0
|
||||
and self.state.global_step % self.args.save_steps == 0
|
||||
):
|
||||
self._save_training(model, trial)
|
||||
|
||||
epoch_pbar.update(1)
|
||||
if self.state.global_step >= max_steps:
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
epoch_pbar.close()
|
||||
train_pbar.update(1)
|
||||
|
||||
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
if self.args.load_best_model_at_end:
|
||||
self._save_training(model, trial, metrics=metrics)
|
||||
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
|
||||
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
if is_torch_tpu_available():
|
||||
@@ -841,12 +734,9 @@ class Trainer:
|
||||
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
||||
"configured. Check your training configuration if this is unexpected."
|
||||
)
|
||||
if self.state.global_step >= max_steps:
|
||||
if self.control.should_training_stop:
|
||||
break
|
||||
|
||||
train_pbar.close()
|
||||
if self.tb_writer:
|
||||
self.tb_writer.close()
|
||||
if self.args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of training
|
||||
delattr(self, "_past")
|
||||
@@ -863,9 +753,36 @@ class Trainer:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
|
||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||
|
||||
def _save_training(self, model, trial, metrics=None):
|
||||
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
|
||||
if self.control.should_log:
|
||||
logs: Dict[str, float] = {}
|
||||
tr_loss_scalar = tr_loss.item()
|
||||
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
|
||||
# backward compatibility for pytorch schedulers
|
||||
logs["learning_rate"] = (
|
||||
self.lr_scheduler.get_last_lr()[0]
|
||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||
else self.lr_scheduler.get_lr()[0]
|
||||
)
|
||||
self._logging_loss_scalar = tr_loss_scalar
|
||||
|
||||
self.log(logs)
|
||||
|
||||
metrics = None
|
||||
if self.control.should_evaluate:
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
|
||||
if self.control.should_save:
|
||||
self._save_checkpoint(model, trial, metrics=metrics)
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
|
||||
# In all cases (even distributed/parallel), self.model is always a reference
|
||||
# to the model we want to save.
|
||||
if hasattr(model, "module"):
|
||||
@@ -896,7 +813,7 @@ class Trainer:
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
# Determine the new best metric / best model checkpoint
|
||||
if metrics is not None:
|
||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||
metric_to_check = self.args.metric_for_best_model
|
||||
if not metric_to_check.startswith("eval_"):
|
||||
metric_to_check = f"eval_{metric_to_check}"
|
||||
@@ -998,7 +915,7 @@ class Trainer:
|
||||
self.hp_search_backend = None
|
||||
return best_run
|
||||
|
||||
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
|
||||
def log(self, logs: Dict[str, float]) -> None:
|
||||
"""
|
||||
Log :obj:`logs` on the various objects watching training.
|
||||
|
||||
@@ -1007,55 +924,22 @@ class Trainer:
|
||||
Args:
|
||||
logs (:obj:`Dict[str, float]`):
|
||||
The values to log.
|
||||
iterator (:obj:`tqdm`, `optional`):
|
||||
A potential tqdm progress bar to write the logs on.
|
||||
"""
|
||||
# Set up loggers like W&B or Comet ML
|
||||
self._setup_loggers()
|
||||
|
||||
if hasattr(self, "_log"):
|
||||
warnings.warn(
|
||||
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._log(logs, iterator=iterator)
|
||||
return self._log(logs)
|
||||
|
||||
if self.state.epoch is not None:
|
||||
logs["epoch"] = self.state.epoch
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
logs["total_flos"] = self.state.total_flos
|
||||
if self.tb_writer:
|
||||
for k, v in logs.items():
|
||||
if isinstance(v, (int, float)):
|
||||
self.tb_writer.add_scalar(k, v, self.state.global_step)
|
||||
else:
|
||||
logger.warning(
|
||||
"Trainer is attempting to log a value of "
|
||||
'"%s" of type %s for key "%s" as a scalar. '
|
||||
"This invocation of Tensorboard's writer.add_scalar() "
|
||||
"is incorrect so we dropped this attribute.",
|
||||
v,
|
||||
type(v),
|
||||
k,
|
||||
)
|
||||
self.tb_writer.flush()
|
||||
if is_wandb_available():
|
||||
if self.is_world_process_zero():
|
||||
wandb.log(logs, step=self.state.global_step)
|
||||
if is_comet_available():
|
||||
if self.is_world_process_zero():
|
||||
experiment = comet_ml.config.get_global_experiment()
|
||||
if experiment is not None:
|
||||
experiment._log_metrics(
|
||||
logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
|
||||
)
|
||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
output = {**logs, **{"step": self.state.global_step}}
|
||||
self.state.log_history.append(output)
|
||||
if iterator is not None:
|
||||
iterator.write(output)
|
||||
else:
|
||||
print(output)
|
||||
|
||||
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||
"""
|
||||
@@ -1372,8 +1256,9 @@ class Trainer:
|
||||
if self.args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
||||
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
||||
self.callback_handler.eval_dataloader = dataloader
|
||||
|
||||
for inputs in dataloader:
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||
if loss is not None:
|
||||
@@ -1382,6 +1267,7 @@ class Trainer:
|
||||
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
|
||||
if labels is not None:
|
||||
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
|
||||
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
||||
|
||||
if self.args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of the evaluation loop
|
||||
|
||||
468
src/transformers/trainer_callback.py
Normal file
468
src/transformers/trainer_callback.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Callbacks to use with the Trainer class and customize the training loop.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .trainer_utils import EvaluationStrategy
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerState:
|
||||
"""
|
||||
A class containing the :class:`~transformers.Trainer` inner state that will be saved along the model and optimizer
|
||||
when checkpointing and passed to the :class:`~transformers.TrainerCallback`.
|
||||
|
||||
.. note::
|
||||
|
||||
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
|
||||
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
|
||||
then one update step requires going throuch `n` batches.
|
||||
|
||||
Args:
|
||||
epoch (:obj:`float`, `optional`):
|
||||
Only set during training, will represent the epoch the training is at (the decimal part being the
|
||||
percentage of the current epoch completed).
|
||||
global_step (:obj:`int`, `optional`, defaults to 0):
|
||||
During training, represents the number of update steps completed.
|
||||
max_steps (:obj:`int`, `optional`, defaults to 0):
|
||||
The number of update steps to do during the current training.
|
||||
total_flos (:obj:`int`, `optional`, defaults to 0):
|
||||
The total number of floating operations done by the model since the beginning of training.
|
||||
log_history (:obj:`List[Dict[str, float]]`, `optional`):
|
||||
The list of logs done since the beginning of training.
|
||||
best_metric (:obj:`float`, `optional`):
|
||||
When tracking the best model, the value of the best metric encountered so far.
|
||||
best_model_checkpoint (:obj:`str`, `optional`):
|
||||
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
|
||||
far.
|
||||
is_local_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
||||
several machines) main process.
|
||||
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not this process is the global main process (when training in a distributed fashion on
|
||||
several machines, this is only going to be :obj:`True` for one process).
|
||||
"""
|
||||
|
||||
epoch: Optional[float] = None
|
||||
global_step: int = 0
|
||||
max_steps: int = 0
|
||||
num_train_epochs: int = 0
|
||||
total_flos: int = 0
|
||||
log_history: List[Dict[str, float]] = None
|
||||
best_metric: Optional[float] = None
|
||||
best_model_checkpoint: Optional[str] = None
|
||||
is_local_process_zero: bool = True
|
||||
is_world_process_zero: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.log_history is None:
|
||||
self.log_history = []
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
|
||||
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
""" Create an instance from the content of :obj:`json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
return cls(**json.loads(text))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerControl:
|
||||
"""
|
||||
A class that handles the :class:`~transformers.Trainer` control flow. This class is used by the
|
||||
:class:`~transformers.TrainerCallback` to activate some switches in the training loop.
|
||||
|
||||
Args:
|
||||
should_training_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the training should be interrupted.
|
||||
|
||||
If :obj:`True`, this variable will not be set back to :obj:`False`. The training will just stop.
|
||||
should_epoch_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the current epoch should be interrupted.
|
||||
|
||||
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next epoch.
|
||||
should_save (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the model should be saved at this step.
|
||||
|
||||
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
|
||||
should_evaluate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the model should be evaluated at this step.
|
||||
|
||||
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
|
||||
should_log (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the logs should be reported at this step.
|
||||
|
||||
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
|
||||
"""
|
||||
|
||||
should_training_stop: bool = False
|
||||
should_epoch_stop: bool = False
|
||||
should_save: bool = False
|
||||
should_evaluate: bool = False
|
||||
should_log: bool = False
|
||||
|
||||
def _new_training(self):
|
||||
""" Internal method that resets the variable for a new training. """
|
||||
self.should_training_stop = False
|
||||
|
||||
def _new_epoch(self):
|
||||
""" Internal method that resets the variable for a new epoch. """
|
||||
self.should_epoch_stop = False
|
||||
|
||||
def _new_step(self):
|
||||
""" Internal method that resets the variable for a new step. """
|
||||
self.should_save_model = False
|
||||
self.should_evaluate = False
|
||||
self.should_log = False
|
||||
|
||||
|
||||
class TrainerCallback:
|
||||
"""
|
||||
A class for objects that will inspect the state of the training loop at some events and take some decisions. At
|
||||
each of those events the following arguments are available:
|
||||
|
||||
Args:
|
||||
args (:class:`~transformers.TrainingArguments`):
|
||||
The training arguments used to instantiate the :class:`~transformers.Trainer`.
|
||||
state (:class:`~transformers.TrainerState`):
|
||||
The current state of the :class:`~transformers.Trainer`.
|
||||
control (:class:`~transformers.TrainerControl`):
|
||||
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
|
||||
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
|
||||
The model being trained.
|
||||
optimizer (:obj:`torch.optim.Optimizer`):
|
||||
The optimizer used for the training steps.
|
||||
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
|
||||
The scheduler used for setting the learning rate.
|
||||
train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
|
||||
The current dataloader used for training.
|
||||
eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
|
||||
The current dataloader used for training.
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics computed by the last evaluation phase.
|
||||
|
||||
Those are only accessible in the event :obj:`on_evaluate`.
|
||||
logs (:obj:`Dict[str, float]`):
|
||||
The values to log.
|
||||
|
||||
Those are only accessible in the event :obj:`on_log`.
|
||||
|
||||
The :obj:`control` object is the only one that can be changed by the callback, in which case the event that changes
|
||||
it should return the modified version.
|
||||
|
||||
The argument :obj:`args`, :obj:`state` and :obj:`control` are positionals for all events, all the others are
|
||||
grouped in :obj:`kwargs`. You can unpack the ones you need in the signature of the event using them. As an example,
|
||||
see the code of the simple :class:`~transformer.PrinterCallback`.
|
||||
|
||||
Example::
|
||||
|
||||
class PrinterCallback(TrainerCallback):
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
_ = logs.pop("total_flos", None)
|
||||
if state.is_local_process_zero:
|
||||
print(logs)
|
||||
"""
|
||||
|
||||
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the end of the initialization of the :class:`~transformers.Trainer`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the beginning of an epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the end of an epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
|
||||
several inputs.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called at the end of a training step. If using gradient accumulation, one training step might take
|
||||
several inputs.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CallbackHandler(TrainerCallback):
|
||||
""" Internal class that just calls the list of callbacks in order. """
|
||||
|
||||
def __init__(self, callbacks, model, optimizer, lr_scheduler):
|
||||
self.callbacks = []
|
||||
for cb in callbacks:
|
||||
self.add_callback(cb)
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.train_dataloader = None
|
||||
self.eval_dataloader = None
|
||||
|
||||
if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
|
||||
logger.warn(
|
||||
"The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
|
||||
+ "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
|
||||
+ "callbacks is\n:"
|
||||
+ self.callback_list
|
||||
)
|
||||
|
||||
def add_callback(self, callback):
|
||||
cb = callback() if isinstance(callback, type) else callback
|
||||
cb_class = callback if isinstance(callback, type) else callback.__class__
|
||||
if cb_class in [c.__class__ for c in self.callbacks]:
|
||||
logger.warn(
|
||||
f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
|
||||
+ "list of callbacks is\n:"
|
||||
+ self.callback_list
|
||||
)
|
||||
self.callbacks.append(cb)
|
||||
|
||||
def pop_callback(self, callback):
|
||||
if isinstance(callback, type):
|
||||
for cb in self.callbacks:
|
||||
if isinstance(cb, callback):
|
||||
self.callbacks.remove(cb)
|
||||
return cb
|
||||
else:
|
||||
for cb in self.callbacks:
|
||||
if cb == callback:
|
||||
self.callbacks.remove(cb)
|
||||
return cb
|
||||
|
||||
def remove_callback(self, callback):
|
||||
if isinstance(callback, type):
|
||||
for cb in self.callbacks:
|
||||
if isinstance(cb, callback):
|
||||
self.callbacks.remove(cb)
|
||||
return
|
||||
else:
|
||||
self.callbacks.remove(callback)
|
||||
|
||||
@property
|
||||
def callback_list(self):
|
||||
return "\n".join(self.callbacks)
|
||||
|
||||
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
return self.call_event("on_init_end", args, state, control)
|
||||
|
||||
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
control.should_training_stop = False
|
||||
return self.call_event("on_train_begin", args, state, control)
|
||||
|
||||
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
return self.call_event("on_train_end", args, state, control)
|
||||
|
||||
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
control.should_epoch_stop = False
|
||||
return self.call_event("on_epoch_begin", args, state, control)
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
return self.call_event("on_epoch_end", args, state, control)
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
control.should_log = False
|
||||
control.should_evaluate = False
|
||||
control.should_save = False
|
||||
return self.call_event("on_step_begin", args, state, control)
|
||||
|
||||
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
return self.call_event("on_step_end", args, state, control)
|
||||
|
||||
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
|
||||
control.should_evaluate = False
|
||||
return self.call_event("on_evaluate", args, state, control, metrics=metrics)
|
||||
|
||||
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
control.should_save = False
|
||||
return self.call_event("on_save", args, state, control)
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
|
||||
control.should_log = False
|
||||
return self.call_event("on_log", args, state, control, logs=logs)
|
||||
|
||||
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||
return self.call_event("on_prediction_step", args, state, control)
|
||||
|
||||
def call_event(self, event, args, state, control, **kwargs):
|
||||
for callback in self.callbacks:
|
||||
result = getattr(callback, event)(
|
||||
args,
|
||||
state,
|
||||
control,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.lr_scheduler,
|
||||
train_dataloader=self.train_dataloader,
|
||||
eval_dataloader=self.eval_dataloader,
|
||||
**kwargs,
|
||||
)
|
||||
# A Callback can skip the return of `control` if it doesn't change it.
|
||||
if result is not None:
|
||||
control = result
|
||||
return control
|
||||
|
||||
|
||||
class DefaultFlowCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation
|
||||
and checkpoints.
|
||||
"""
|
||||
|
||||
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
# Log
|
||||
if state.global_step == 1 and args.logging_first_step:
|
||||
control.should_log = True
|
||||
if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
|
||||
# Save
|
||||
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
|
||||
control.should_save = True
|
||||
|
||||
# End training
|
||||
if state.global_step >= state.max_steps:
|
||||
control.should_training_stop = True
|
||||
|
||||
return control
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
return control
|
||||
|
||||
|
||||
class ProgressCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.training_bar = tqdm(total=state.max_steps)
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.training_bar.update(1)
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
if self.prediction_bar is None:
|
||||
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
|
||||
self.prediction_bar.update(1)
|
||||
|
||||
def on_evaluate(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.prediction_bar.close()
|
||||
self.prediction_bar = None
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_local_process_zero and self.training_bar is not None:
|
||||
_ = logs.pop("total_flos", None)
|
||||
self.training_bar.write(str(logs))
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.training_bar.close()
|
||||
self.training_bar = None
|
||||
|
||||
|
||||
class PrinterCallback(TrainerCallback):
|
||||
"""
|
||||
A bare :class:`~transformers.TrainerCallback` that just prints the logs.
|
||||
"""
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
_ = logs.pop("total_flos", None)
|
||||
if state.is_local_process_zero:
|
||||
print(logs)
|
||||
179
src/transformers/trainer_pt_utils.py
Normal file
179
src/transformers/trainer_pt_utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Torch utilities for the Trainer class.
|
||||
"""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
from .file_utils import is_torch_tpu_available
|
||||
|
||||
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||
|
||||
|
||||
def nested_concat(tensors, new_tensors, dim=0):
|
||||
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
|
||||
assert type(tensors) == type(
|
||||
new_tensors
|
||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
|
||||
return torch.cat((tensors, new_tensors), dim=dim)
|
||||
|
||||
|
||||
def nested_numpify(tensors):
|
||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_numpify(t) for t in tensors)
|
||||
return tensors.cpu().numpy()
|
||||
|
||||
|
||||
def nested_detach(tensors):
|
||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t) for t in tensors)
|
||||
return tensors.detach()
|
||||
|
||||
|
||||
def nested_xla_mesh_reduce(tensors, name):
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
||||
return xm.mesh_reduce(name, tensors, torch.cat)
|
||||
else:
|
||||
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
||||
|
||||
|
||||
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
|
||||
try:
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensor)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
except AssertionError:
|
||||
raise AssertionError("Not currently using distributed training")
|
||||
|
||||
|
||||
def distributed_broadcast_scalars(
|
||||
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
except AssertionError:
|
||||
raise AssertionError("Not currently using distributed training")
|
||||
|
||||
|
||||
def reissue_pt_warnings(caught_warnings):
|
||||
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
||||
if len(caught_warnings) > 1:
|
||||
for w in caught_warnings:
|
||||
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
||||
warnings.warn(w.message, w.category)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
"""
|
||||
Decorator to make all processes in distributed training wait for each local_master to do something.
|
||||
|
||||
Args:
|
||||
local_rank (:obj:`int`): The rank of the local process.
|
||||
"""
|
||||
if local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier()
|
||||
yield
|
||||
if local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
class SequentialDistributedSampler(Sampler):
|
||||
"""
|
||||
Distributed Sampler that subsamples indicies sequentially,
|
||||
making it easier to collate all results at the end.
|
||||
|
||||
Even though we only use this sampler for eval and predict (no training),
|
||||
which means that the model params won't have to be synced (i.e. will not hang
|
||||
for synchronization even if varied number of forward passes), we still add extra
|
||||
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
|
||||
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||
if num_replicas is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
if rank is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = torch.distributed.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
|
||||
def __iter__(self):
|
||||
indices = list(range(len(self.dataset)))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[: (self.total_size - len(indices))]
|
||||
assert (
|
||||
len(indices) == self.total_size
|
||||
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
||||
assert (
|
||||
len(indices) == self.num_samples
|
||||
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
|
||||
if xm.xrt_world_size() <= 1:
|
||||
return RandomSampler(dataset)
|
||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||
@@ -1,19 +1,30 @@
|
||||
import dataclasses
|
||||
import json
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
|
||||
from .file_utils import is_tf_available, is_torch_available
|
||||
from .tokenization_utils_base import ExplicitEnum
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
|
||||
@@ -139,144 +150,3 @@ default_hp_space = {
|
||||
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
||||
HPSearchBackend.RAY: default_hp_space_ray,
|
||||
}
|
||||
|
||||
|
||||
def nested_concat(tensors, new_tensors, dim=0):
|
||||
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
|
||||
if is_torch_available():
|
||||
assert type(tensors) == type(
|
||||
new_tensors
|
||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
|
||||
return torch.cat((tensors, new_tensors), dim=dim)
|
||||
else:
|
||||
raise ImportError("Torch must be installed to use `nested_concat`")
|
||||
|
||||
|
||||
def nested_deatch(tensors):
|
||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t) for t in tensors)
|
||||
return tensors.detach()
|
||||
|
||||
|
||||
def nested_numpify(tensors):
|
||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_numpify(t) for t in tensors)
|
||||
return tensors.cpu().numpy()
|
||||
|
||||
|
||||
def nested_detach(tensors):
|
||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t) for t in tensors)
|
||||
return tensors.detach()
|
||||
|
||||
|
||||
def nested_xla_mesh_reduce(tensors, name):
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
||||
return xm.mesh_reduce(name, tensors, torch.cat)
|
||||
else:
|
||||
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
||||
|
||||
|
||||
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
|
||||
if is_torch_available():
|
||||
try:
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensor)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
except AssertionError:
|
||||
raise AssertionError("Not currently using distributed training")
|
||||
else:
|
||||
raise ImportError("Torch must be installed to use `distributed_concat`")
|
||||
|
||||
|
||||
def distributed_broadcast_scalars(
|
||||
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
|
||||
) -> "torch.Tensor":
|
||||
if is_torch_available():
|
||||
try:
|
||||
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
except AssertionError:
|
||||
raise AssertionError("Not currently using distributed training")
|
||||
else:
|
||||
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerState:
|
||||
"""
|
||||
A class containing the `Trainer` inner state that will be saved along the model and optimizer.
|
||||
|
||||
.. note::
|
||||
|
||||
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
|
||||
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
|
||||
then one update step requires going throuch `n` batches.
|
||||
|
||||
Args:
|
||||
epoch (:obj:`float`, `optional`):
|
||||
Only set during training, will represent the epoch the training is at (the decimal part being the
|
||||
percentage of the current epoch completed).
|
||||
global_step (:obj:`int`, `optional`, defaults to 0):
|
||||
During training, represents the number of update steps completed.
|
||||
max_steps (:obj:`int`, `optional`, defaults to 0):
|
||||
The number of update steps to do during the current training.
|
||||
total_flos (:obj:`int`, `optional`, defaults to 0):
|
||||
The total number of floating operations done by the model since the beginning of training.
|
||||
log_history (:obj:`List[Dict[str, float]]`, `optional`):
|
||||
The list of logs done since the beginning of training.
|
||||
best_metric (:obj:`float`, `optional`):
|
||||
When tracking the best model, the value of the best metric encountered so far.
|
||||
best_model_checkpoint (:obj:`str`, `optional`):
|
||||
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
|
||||
far.
|
||||
"""
|
||||
|
||||
epoch: Optional[float] = None
|
||||
global_step: int = 0
|
||||
max_steps: int = 0
|
||||
num_train_epochs: int = 0
|
||||
total_flos: int = 0
|
||||
log_history: List[Dict[str, float]] = None
|
||||
best_metric: Optional[float] = None
|
||||
best_model_checkpoint: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.log_history is None:
|
||||
self.log_history = []
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
|
||||
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
""" Create an instance from the content of :obj:`json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
return cls(**json.loads(text))
|
||||
|
||||
@@ -1869,19 +1869,10 @@ class MarianTokenizer:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class EvalPrediction:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
def set_seed(*args, **kwargs):
|
||||
requires_pytorch(set_seed)
|
||||
|
||||
|
||||
def torch_distributed_zero_first(*args, **kwargs):
|
||||
requires_pytorch(torch_distributed_zero_first)
|
||||
|
||||
214
tests/test_trainer_callback.py
Normal file
214
tests/test_trainer_callback.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
DefaultFlowCallback,
|
||||
EvaluationStrategy,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.trainer import DEFAULT_CALLBACKS
|
||||
|
||||
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
||||
|
||||
|
||||
class TestTrainerCallback(TrainerCallback):
|
||||
"A callback that registers the events that goes through."
|
||||
|
||||
def __init__(self):
|
||||
self.events = []
|
||||
|
||||
def on_init_end(self, args, state, control, **kwargs):
|
||||
self.events.append("on_init_end")
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
self.events.append("on_train_begin")
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
self.events.append("on_train_end")
|
||||
|
||||
def on_epoch_begin(self, args, state, control, **kwargs):
|
||||
self.events.append("on_epoch_begin")
|
||||
|
||||
def on_epoch_end(self, args, state, control, **kwargs):
|
||||
self.events.append("on_epoch_end")
|
||||
|
||||
def on_step_begin(self, args, state, control, **kwargs):
|
||||
self.events.append("on_step_begin")
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
self.events.append("on_step_end")
|
||||
|
||||
def on_evaluate(self, args, state, control, **kwargs):
|
||||
self.events.append("on_evaluate")
|
||||
|
||||
def on_save(self, args, state, control, **kwargs):
|
||||
self.events.append("on_save")
|
||||
|
||||
def on_log(self, args, state, control, **kwargs):
|
||||
self.events.append("on_log")
|
||||
|
||||
def on_prediction_step(self, args, state, control, **kwargs):
|
||||
self.events.append("on_prediction_step")
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerCallbackTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.output_dir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.output_dir)
|
||||
|
||||
def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
|
||||
# disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure
|
||||
# its set to False since the tests later on depend on its value.
|
||||
train_dataset = RegressionDataset(length=train_len)
|
||||
eval_dataset = RegressionDataset(length=eval_len)
|
||||
config = RegressionModelConfig(a=a, b=b)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
|
||||
args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, **kwargs)
|
||||
return Trainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
def check_callbacks_equality(self, cbs1, cbs2):
|
||||
self.assertEqual(len(cbs1), len(cbs2))
|
||||
|
||||
# Order doesn't matter
|
||||
cbs1 = list(sorted(cbs1, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
|
||||
cbs2 = list(sorted(cbs2, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
|
||||
|
||||
for cb1, cb2 in zip(cbs1, cbs2):
|
||||
if isinstance(cb1, type) and isinstance(cb2, type):
|
||||
self.assertEqual(cb1, cb2)
|
||||
elif isinstance(cb1, type) and not isinstance(cb2, type):
|
||||
self.assertEqual(cb1, cb2.__class__)
|
||||
elif not isinstance(cb1, type) and isinstance(cb2, type):
|
||||
self.assertEqual(cb1.__class__, cb2)
|
||||
else:
|
||||
self.assertEqual(cb1, cb2)
|
||||
|
||||
def get_expected_events(self, trainer):
|
||||
expected_events = ["on_init_end", "on_train_begin"]
|
||||
step = 0
|
||||
train_dl_len = len(trainer.get_eval_dataloader())
|
||||
evaluation_events = ["on_prediction_step"] * len(trainer.get_eval_dataloader()) + ["on_log", "on_evaluate"]
|
||||
for _ in range(trainer.state.num_train_epochs):
|
||||
expected_events.append("on_epoch_begin")
|
||||
for _ in range(train_dl_len):
|
||||
step += 1
|
||||
expected_events += ["on_step_begin", "on_step_end"]
|
||||
if step % trainer.args.logging_steps == 0:
|
||||
expected_events.append("on_log")
|
||||
if (
|
||||
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||
and step % trainer.args.eval_steps == 0
|
||||
):
|
||||
expected_events += evaluation_events.copy()
|
||||
if step % trainer.args.save_steps == 0:
|
||||
expected_events.append("on_save")
|
||||
expected_events.append("on_epoch_end")
|
||||
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
expected_events += evaluation_events.copy()
|
||||
expected_events.append("on_train_end")
|
||||
return expected_events
|
||||
|
||||
def test_init_callback(self):
|
||||
trainer = self.get_trainer()
|
||||
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
# Callbacks passed at init are added to the default callbacks
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
||||
expected_callbacks.append(TestTrainerCallback)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
|
||||
trainer = self.get_trainer(disable_tqdm=True)
|
||||
expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
def test_add_remove_callback(self):
|
||||
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
|
||||
trainer = self.get_trainer()
|
||||
|
||||
# We can add, pop, or remove by class name
|
||||
trainer.remove_callback(DefaultFlowCallback)
|
||||
expected_callbacks.remove(DefaultFlowCallback)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
trainer = self.get_trainer()
|
||||
cb = trainer.pop_callback(DefaultFlowCallback)
|
||||
self.assertEqual(cb.__class__, DefaultFlowCallback)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
trainer.add_callback(DefaultFlowCallback)
|
||||
expected_callbacks.insert(0, DefaultFlowCallback)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
# We can also add, pop, or remove by instance
|
||||
trainer = self.get_trainer()
|
||||
cb = trainer.callback_handler.callbacks[0]
|
||||
trainer.remove_callback(cb)
|
||||
expected_callbacks.remove(DefaultFlowCallback)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
trainer = self.get_trainer()
|
||||
cb1 = trainer.callback_handler.callbacks[0]
|
||||
cb2 = trainer.pop_callback(cb1)
|
||||
self.assertEqual(cb1, cb2)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
trainer.add_callback(cb1)
|
||||
expected_callbacks.insert(0, DefaultFlowCallback)
|
||||
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||
|
||||
def test_event_flow(self):
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# Independent log/save/eval
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch")
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
|
||||
# A bit of everything
|
||||
trainer = self.get_trainer(
|
||||
callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps"
|
||||
)
|
||||
trainer.train()
|
||||
events = trainer.callback_handler.callbacks[-2].events
|
||||
self.assertEqual(events, self.get_expected_events(trainer))
|
||||
Reference in New Issue
Block a user