Trainer callbacks (#7596)
* Initial callback proposal * Finish various callbacks * Post-rebase conflicts * Fix tests * Don't use something that's not set * Documentation * Remove unwanted print. * Document all models can work * Add tests + small fixes * Update docs/source/internal/trainer_utils.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Fix TF tests * Real fix this time * This one should work * Fix typo * Really fix typo Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -213,6 +213,7 @@ conversion utilities for the following models:
|
|||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
:caption: Main Classes
|
:caption: Main Classes
|
||||||
|
|
||||||
|
main_classes/callback
|
||||||
main_classes/configuration
|
main_classes/configuration
|
||||||
main_classes/logging
|
main_classes/logging
|
||||||
main_classes/model
|
main_classes/model
|
||||||
@@ -270,3 +271,4 @@ conversion utilities for the following models:
|
|||||||
internal/modeling_utils
|
internal/modeling_utils
|
||||||
internal/pipelines_utils
|
internal/pipelines_utils
|
||||||
internal/tokenization_utils
|
internal/tokenization_utils
|
||||||
|
internal/trainer_utils
|
||||||
|
|||||||
21
docs/source/internal/trainer_utils.rst
Normal file
21
docs/source/internal/trainer_utils.rst
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
Utilities for Trainer
|
||||||
|
-----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
This page lists all the utility functions used by :class:`~transformers.Trainer`.
|
||||||
|
|
||||||
|
Most of those are only useful if you are studying the code of the Trainer in the library.
|
||||||
|
|
||||||
|
Utilities
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.EvalPrediction
|
||||||
|
|
||||||
|
.. autofunction:: transformers.set_seed
|
||||||
|
|
||||||
|
.. autofunction:: transformers.torch_distributed_zero_first
|
||||||
|
|
||||||
|
|
||||||
|
Callbacks internals
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.trainer_callback.CallbackHandler
|
||||||
68
docs/source/main_classes/callback.rst
Normal file
68
docs/source/main_classes/callback.rst
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
Callbacks
|
||||||
|
-----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Callbacks are objects that can customize the behavior of the training loop in the PyTorch
|
||||||
|
:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow) that can inspect the training loop
|
||||||
|
state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early
|
||||||
|
stopping).
|
||||||
|
|
||||||
|
Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
|
||||||
|
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
|
||||||
|
subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).
|
||||||
|
|
||||||
|
By default a :class:`~transformers.Trainer` will use the following callbacks:
|
||||||
|
|
||||||
|
- :class:`~transformers.DefaultFlowCallback` which handles the default beahvior for logging, saving and evaluation.
|
||||||
|
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
|
||||||
|
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
|
||||||
|
it's the second one).
|
||||||
|
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
|
||||||
|
or tensorboardX).
|
||||||
|
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
|
||||||
|
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
|
||||||
|
|
||||||
|
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
|
||||||
|
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
|
||||||
|
Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
|
||||||
|
:class:`~transformers.TrainerControl`.
|
||||||
|
|
||||||
|
|
||||||
|
Available Callbacks
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Here is the list of the available :class:`~transformers.TrainerCallback` in the library:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.integrations.CometCallback
|
||||||
|
:members: setup
|
||||||
|
|
||||||
|
.. autoclass:: transformers.DefaultFlowCallback
|
||||||
|
|
||||||
|
.. autoclass:: transformers.PrinterCallback
|
||||||
|
|
||||||
|
.. autoclass:: transformers.ProgressCallback
|
||||||
|
|
||||||
|
.. autoclass:: transformers.integrations.TensorBoardCallback
|
||||||
|
|
||||||
|
.. autoclass:: transformers.integrations.WandbCallback
|
||||||
|
:members: setup
|
||||||
|
|
||||||
|
|
||||||
|
TrainerCallback
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TrainerCallback
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
TrainerState
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TrainerState
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
TrainerControl
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TrainerControl
|
||||||
|
:members:
|
||||||
@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override
|
|||||||
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset.
|
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset.
|
||||||
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
|
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
|
||||||
- **log** -- Logs information on the various objects watching training.
|
- **log** -- Logs information on the various objects watching training.
|
||||||
- **setup_wandb** -- Setups wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
|
|
||||||
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
|
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
|
||||||
init.
|
init.
|
||||||
- **compute_loss** - Computes the loss on a batch of training inputs.
|
- **compute_loss** - Computes the loss on a batch of training inputs.
|
||||||
@@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu
|
|||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
return my_custom_loss(logits, labels)
|
return my_custom_loss(logits, labels)
|
||||||
|
|
||||||
|
Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
|
||||||
|
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
|
||||||
|
other ML platforms...) and take decisions (like early stopping).
|
||||||
|
|
||||||
|
|
||||||
Trainer
|
Trainer
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
@@ -47,29 +50,23 @@ Trainer
|
|||||||
.. autoclass:: transformers.Trainer
|
.. autoclass:: transformers.Trainer
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFTrainer
|
TFTrainer
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.TFTrainer
|
.. autoclass:: transformers.TFTrainer
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TrainingArguments
|
TrainingArguments
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.TrainingArguments
|
.. autoclass:: transformers.TrainingArguments
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFTrainingArguments
|
TFTrainingArguments
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.TFTrainingArguments
|
.. autoclass:: transformers.TFTrainingArguments
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
Utilities
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
|
|
||||||
.. autoclass:: transformers.EvalPrediction
|
|
||||||
|
|
||||||
.. autofunction:: transformers.set_seed
|
|
||||||
|
|
||||||
.. autofunction:: transformers.torch_distributed_zero_first
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from transformers import Trainer
|
|||||||
from transformers.configuration_fsmt import FSMTConfig
|
from transformers.configuration_fsmt import FSMTConfig
|
||||||
from transformers.file_utils import is_torch_tpu_available
|
from transformers.file_utils import is_torch_tpu_available
|
||||||
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
|
||||||
from transformers.trainer import get_tpu_sampler
|
from transformers.trainer_pt_utils import get_tpu_sampler
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import tempfile
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
from transformers.trainer_utils import TrainerState, set_seed
|
from transformers.trainer_callback import TrainerState
|
||||||
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
from .finetune_trainer import main
|
from .finetune_trainer import main
|
||||||
from .test_seq2seq_examples import MBART_TINY
|
from .test_seq2seq_examples import MBART_TINY
|
||||||
|
|||||||
@@ -205,7 +205,15 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
|||||||
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
from .trainer_utils import EvalPrediction, TrainerState, set_seed
|
from .trainer_callback import (
|
||||||
|
DefaultFlowCallback,
|
||||||
|
PrinterCallback,
|
||||||
|
ProgressCallback,
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
)
|
||||||
|
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
from .training_args_tf import TFTrainingArguments
|
from .training_args_tf import TFTrainingArguments
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
@@ -529,7 +537,8 @@ if is_torch_available():
|
|||||||
from .tokenization_marian import MarianTokenizer
|
from .tokenization_marian import MarianTokenizer
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
|
from .trainer import Trainer
|
||||||
|
from .trainer_pt_utils import torch_distributed_zero_first
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_pt_objects import *
|
from .utils.dummy_pt_objects import *
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,11 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from .file_utils import is_torch_tpu_available
|
||||||
|
from .trainer_callback import TrainerCallback
|
||||||
|
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import comet_ml # noqa: F401
|
import comet_ml # noqa: F401
|
||||||
@@ -36,15 +41,6 @@ try:
|
|||||||
except (ImportError):
|
except (ImportError):
|
||||||
_has_ray = False
|
_has_ray = False
|
||||||
|
|
||||||
|
|
||||||
# No ML framework or transformer imports above this point
|
|
||||||
|
|
||||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # isort:skip
|
|
||||||
from .utils import logging # isort:skip
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
from torch.utils.tensorboard import SummaryWriter # noqa: F401
|
||||||
|
|
||||||
@@ -57,9 +53,10 @@ except ImportError:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
_has_tensorboard = False
|
_has_tensorboard = False
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Integration functions:
|
# Integration functions:
|
||||||
|
|
||||||
|
|
||||||
def is_wandb_available():
|
def is_wandb_available():
|
||||||
return _has_wandb
|
return _has_wandb
|
||||||
|
|
||||||
@@ -128,8 +125,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||||||
|
|
||||||
# The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
|
# The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
|
||||||
# while doing the ray hp search.
|
# while doing the ray hp search.
|
||||||
_tb_writer = trainer.tb_writer
|
|
||||||
trainer.tb_writer = None
|
_tb_writer = trainer.pop_callback(TensorBoardCallback)
|
||||||
trainer.model = None
|
trainer.model = None
|
||||||
# Setup default `resources_per_trial` and `reporter`.
|
# Setup default `resources_per_trial` and `reporter`.
|
||||||
if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
|
if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
|
||||||
@@ -182,5 +179,159 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||||||
analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
|
analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
|
||||||
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
|
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
|
||||||
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
|
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
|
||||||
trainer.tb_writer = _tb_writer
|
if _tb_writer is not None:
|
||||||
|
trainer.add_callback(_tb_writer)
|
||||||
return best_run
|
return best_run
|
||||||
|
|
||||||
|
|
||||||
|
class TensorBoardCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
|
||||||
|
<https://www.tensorflow.org/tensorboard>`__.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tb_writer (:obj:`SummaryWriter`, `optional`):
|
||||||
|
The writer to use. Will instatiate one if not set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tb_writer=None):
|
||||||
|
assert (
|
||||||
|
_has_tensorboard
|
||||||
|
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
|
||||||
|
self.tb_writer = tb_writer
|
||||||
|
|
||||||
|
def on_init_end(self, args, state, control, **kwargs):
|
||||||
|
if self.tb_writer is None and state.is_world_process_zero:
|
||||||
|
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
|
if self.tb_writer is not None:
|
||||||
|
self.tb_writer.add_text("args", args.to_json_string())
|
||||||
|
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
if self.tb_writer:
|
||||||
|
for k, v in logs.items():
|
||||||
|
if isinstance(v, (int, float)):
|
||||||
|
self.tb_writer.add_scalar(k, v, state.global_step)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Trainer is attempting to log a value of "
|
||||||
|
'"%s" of type %s for key "%s" as a scalar. '
|
||||||
|
"This invocation of Tensorboard's writer.add_scalar() "
|
||||||
|
"is incorrect so we dropped this attribute.",
|
||||||
|
v,
|
||||||
|
type(v),
|
||||||
|
k,
|
||||||
|
)
|
||||||
|
self.tb_writer.flush()
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
if self.tb_writer:
|
||||||
|
self.tb_writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
class WandbCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
A :class:`~transformers.TrainerCallback` that sends the logs to `Weight and Biases
|
||||||
|
<https://www.wandb.com/>`__.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
def setup(self, args, state, model):
|
||||||
|
"""
|
||||||
|
Setup the optional Weights & Biases (`wandb`) integration.
|
||||||
|
|
||||||
|
One can subclass and override this method to customize the setup if needed. Find more information
|
||||||
|
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
|
||||||
|
|
||||||
|
Environment:
|
||||||
|
WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
|
||||||
|
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
|
||||||
|
logging or :obj:`"all"` to log gradients and parameters.
|
||||||
|
WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
|
||||||
|
Set this to a custom string to store results in a different project.
|
||||||
|
WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to disable wandb entirely.
|
||||||
|
"""
|
||||||
|
self._initialized = True
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
logger.info(
|
||||||
|
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||||
|
)
|
||||||
|
combined_dict = {**args.to_sanitized_dict()}
|
||||||
|
if hasattr(model, "config"):
|
||||||
|
combined_dict = {**model.config.to_dict(), **combined_dict}
|
||||||
|
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
|
||||||
|
# keep track of model topology and gradients, unsupported on TPU
|
||||||
|
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||||
|
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup(args, state, model)
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup(args, state, model)
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
wandb.log(logs, step=state.global_step)
|
||||||
|
|
||||||
|
|
||||||
|
class CometCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML
|
||||||
|
<https://www.comet.ml/site/>`__.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
def setup(self, args, state, model):
|
||||||
|
"""
|
||||||
|
Setup the optional Comet.ml integration.
|
||||||
|
|
||||||
|
Environment:
|
||||||
|
COMET_MODE (:obj:`str`, `optional`):
|
||||||
|
"OFFLINE", "ONLINE", or "DISABLED"
|
||||||
|
COMET_PROJECT_NAME (:obj:`str`, `optional`):
|
||||||
|
Comet.ml project name for experiments
|
||||||
|
COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
|
||||||
|
Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"
|
||||||
|
|
||||||
|
For a number of configurable items in the environment,
|
||||||
|
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
|
||||||
|
"""
|
||||||
|
self._initialized = True
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
|
||||||
|
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
|
||||||
|
experiment = None
|
||||||
|
if comet_mode == "ONLINE":
|
||||||
|
experiment = comet_ml.Experiment(**args)
|
||||||
|
logger.info("Automatic Comet.ml online logging enabled")
|
||||||
|
elif comet_mode == "OFFLINE":
|
||||||
|
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
|
||||||
|
experiment = comet_ml.OfflineExperiment(**args)
|
||||||
|
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
|
||||||
|
if experiment is not None:
|
||||||
|
experiment._set_model_graph(model, framework="transformers")
|
||||||
|
experiment._log_parameters(args, prefix="args/", framework="transformers")
|
||||||
|
if hasattr(model, "config"):
|
||||||
|
experiment._log_parameters(model.config, prefix="config/", framework="transformers")
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup(args, state, model)
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup(args, state, model)
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
experiment = comet_ml.config.get_global_experiment()
|
||||||
|
if experiment is not None:
|
||||||
|
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
|
||||||
|
|||||||
@@ -1,10 +1,26 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020-present the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
|
||||||
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -15,8 +31,7 @@ from torch import nn
|
|||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||||
from tqdm.auto import tqdm, trange
|
|
||||||
|
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
|
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
|
||||||
@@ -34,23 +49,35 @@ from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
|||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
from .optimization import AdamW, get_linear_schedule_with_warmup
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from .trainer_utils import (
|
from .trainer_callback import (
|
||||||
PREFIX_CHECKPOINT_DIR,
|
CallbackHandler,
|
||||||
BestRun,
|
DefaultFlowCallback,
|
||||||
EvalPrediction,
|
PrinterCallback,
|
||||||
EvaluationStrategy,
|
ProgressCallback,
|
||||||
HPSearchBackend,
|
TrainerCallback,
|
||||||
PredictionOutput,
|
TrainerControl,
|
||||||
TrainerState,
|
TrainerState,
|
||||||
TrainOutput,
|
)
|
||||||
default_compute_objective,
|
from .trainer_pt_utils import (
|
||||||
default_hp_space,
|
SequentialDistributedSampler,
|
||||||
distributed_broadcast_scalars,
|
distributed_broadcast_scalars,
|
||||||
distributed_concat,
|
distributed_concat,
|
||||||
|
get_tpu_sampler,
|
||||||
nested_concat,
|
nested_concat,
|
||||||
nested_detach,
|
nested_detach,
|
||||||
nested_numpify,
|
nested_numpify,
|
||||||
nested_xla_mesh_reduce,
|
nested_xla_mesh_reduce,
|
||||||
|
reissue_pt_warnings,
|
||||||
|
)
|
||||||
|
from .trainer_utils import (
|
||||||
|
PREFIX_CHECKPOINT_DIR,
|
||||||
|
BestRun,
|
||||||
|
EvalPrediction,
|
||||||
|
HPSearchBackend,
|
||||||
|
PredictionOutput,
|
||||||
|
TrainOutput,
|
||||||
|
default_compute_objective,
|
||||||
|
default_hp_space,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
@@ -60,7 +87,8 @@ from .utils import logging
|
|||||||
_use_native_amp = False
|
_use_native_amp = False
|
||||||
_use_apex = False
|
_use_apex = False
|
||||||
|
|
||||||
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||||
|
|
||||||
|
|
||||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
||||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
if version.parse(torch.__version__) < version.parse("1.6"):
|
||||||
@@ -82,16 +110,20 @@ if is_torch_tpu_available():
|
|||||||
import torch_xla.distributed.parallel_loader as pl
|
import torch_xla.distributed.parallel_loader as pl
|
||||||
|
|
||||||
if is_tensorboard_available():
|
if is_tensorboard_available():
|
||||||
try:
|
from .integrations import TensorBoardCallback
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
except ImportError:
|
DEFAULT_CALLBACKS.append(TensorBoardCallback)
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
if is_wandb_available():
|
if is_wandb_available():
|
||||||
import wandb
|
from .integrations import WandbCallback
|
||||||
|
|
||||||
|
DEFAULT_CALLBACKS.append(WandbCallback)
|
||||||
|
|
||||||
if is_comet_available():
|
if is_comet_available():
|
||||||
import comet_ml
|
from .integrations import CometCallback
|
||||||
|
|
||||||
|
DEFAULT_CALLBACKS.append(CometCallback)
|
||||||
|
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
import optuna
|
import optuna
|
||||||
@@ -102,91 +134,20 @@ if is_ray_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def reissue_pt_warnings(caught_warnings):
|
|
||||||
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
|
||||||
if len(caught_warnings) > 1:
|
|
||||||
for w in caught_warnings:
|
|
||||||
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
|
||||||
warnings.warn(w.message, w.category)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def torch_distributed_zero_first(local_rank: int):
|
|
||||||
"""
|
|
||||||
Decorator to make all processes in distributed training wait for each local_master to do something.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
local_rank (:obj:`int`): The rank of the local process.
|
|
||||||
"""
|
|
||||||
if local_rank not in [-1, 0]:
|
|
||||||
torch.distributed.barrier()
|
|
||||||
yield
|
|
||||||
if local_rank == 0:
|
|
||||||
torch.distributed.barrier()
|
|
||||||
|
|
||||||
|
|
||||||
class SequentialDistributedSampler(Sampler):
|
|
||||||
"""
|
|
||||||
Distributed Sampler that subsamples indicies sequentially,
|
|
||||||
making it easier to collate all results at the end.
|
|
||||||
|
|
||||||
Even though we only use this sampler for eval and predict (no training),
|
|
||||||
which means that the model params won't have to be synced (i.e. will not hang
|
|
||||||
for synchronization even if varied number of forward passes), we still add extra
|
|
||||||
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
|
|
||||||
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
|
||||||
if num_replicas is None:
|
|
||||||
if not torch.distributed.is_available():
|
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
|
||||||
num_replicas = torch.distributed.get_world_size()
|
|
||||||
if rank is None:
|
|
||||||
if not torch.distributed.is_available():
|
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
|
||||||
rank = torch.distributed.get_rank()
|
|
||||||
self.dataset = dataset
|
|
||||||
self.num_replicas = num_replicas
|
|
||||||
self.rank = rank
|
|
||||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
|
||||||
self.total_size = self.num_samples * self.num_replicas
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
indices = list(range(len(self.dataset)))
|
|
||||||
|
|
||||||
# add extra samples to make it evenly divisible
|
|
||||||
indices += indices[: (self.total_size - len(indices))]
|
|
||||||
assert (
|
|
||||||
len(indices) == self.total_size
|
|
||||||
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
|
|
||||||
|
|
||||||
# subsample
|
|
||||||
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
|
||||||
assert (
|
|
||||||
len(indices) == self.num_samples
|
|
||||||
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
|
||||||
|
|
||||||
return iter(indices)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_samples
|
|
||||||
|
|
||||||
|
|
||||||
def get_tpu_sampler(dataset: Dataset):
|
|
||||||
if xm.xrt_world_size() <= 1:
|
|
||||||
return RandomSampler(dataset)
|
|
||||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""
|
"""
|
||||||
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
||||||
optimized for 🤗 Transformers.
|
optimized for 🤗 Transformers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (:class:`~transformers.PreTrainedModel`, `optional`):
|
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
|
||||||
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
|
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
|
||||||
|
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
|
||||||
|
they work the same way as the 🤗 Transformers models.
|
||||||
args (:class:`~transformers.TrainingArguments`, `optional`):
|
args (:class:`~transformers.TrainingArguments`, `optional`):
|
||||||
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
|
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
|
||||||
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
|
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
|
||||||
@@ -210,8 +171,11 @@ class Trainer:
|
|||||||
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
||||||
The function that will be used to compute metrics at evaluation. Must take a
|
The function that will be used to compute metrics at evaluation. Must take a
|
||||||
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
||||||
tb_writer (:obj:`SummaryWriter`, `optional`):
|
callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
|
||||||
Object to write to TensorBoard.
|
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
|
||||||
|
detailed in :doc:`here <callback>`.
|
||||||
|
|
||||||
|
If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
|
||||||
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
|
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
|
||||||
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
|
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
|
||||||
:class:`~transformers.AdamW` on your model and a scheduler given by
|
:class:`~transformers.AdamW` on your model and a scheduler given by
|
||||||
@@ -222,7 +186,7 @@ class Trainer:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: PreTrainedModel = None,
|
model: Union[PreTrainedModel, torch.nn.Module] = None,
|
||||||
args: TrainingArguments = None,
|
args: TrainingArguments = None,
|
||||||
data_collator: Optional[DataCollator] = None,
|
data_collator: Optional[DataCollator] = None,
|
||||||
train_dataset: Optional[Dataset] = None,
|
train_dataset: Optional[Dataset] = None,
|
||||||
@@ -230,7 +194,7 @@ class Trainer:
|
|||||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
model_init: Callable[[], PreTrainedModel] = None,
|
model_init: Callable[[], PreTrainedModel] = None,
|
||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||||
tb_writer: Optional["SummaryWriter"] = None,
|
callbacks: Optional[List[TrainerCallback]] = None,
|
||||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -259,7 +223,21 @@ class Trainer:
|
|||||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
|
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
|
||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
self.tb_writer = tb_writer
|
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
||||||
|
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
|
||||||
|
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
|
||||||
|
|
||||||
|
# Deprecated arguments
|
||||||
|
if "tb_writer" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `tb_writer` as a keyword argument is deprecated and won't be possible in a "
|
||||||
|
+ "future version. Use `TensorBoardCallback(tb_writer=...)` instead and pass it to the `callbacks`"
|
||||||
|
+ "argument",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
tb_writer = kwargs.pop("tb_writer")
|
||||||
|
self.remove_callback(TensorBoardCallback)
|
||||||
|
self.add_callback(TensorBoardCallback(tb_writer=tb_writer))
|
||||||
if "prediction_loss_only" in kwargs:
|
if "prediction_loss_only" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
|
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
|
||||||
@@ -270,13 +248,6 @@ class Trainer:
|
|||||||
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
|
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|
||||||
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
|
|
||||||
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
|
||||||
if not is_tensorboard_available():
|
|
||||||
logger.warning(
|
|
||||||
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
||||||
self._loggers_initialized = False
|
self._loggers_initialized = False
|
||||||
|
|
||||||
@@ -304,6 +275,7 @@ class Trainer:
|
|||||||
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
||||||
|
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
|
self.control = TrainerControl()
|
||||||
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
||||||
# state at each call to self.log.
|
# state at each call to self.log.
|
||||||
self._total_flos = None
|
self._total_flos = None
|
||||||
@@ -317,6 +289,45 @@ class Trainer:
|
|||||||
else ["labels"]
|
else ["labels"]
|
||||||
)
|
)
|
||||||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||||
|
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
|
def add_callback(self, callback):
|
||||||
|
"""
|
||||||
|
Add a callback to the current list of :class:`~transformer.TrainerCallback`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
||||||
|
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
||||||
|
In the first case, will instantiate a member of that class.
|
||||||
|
"""
|
||||||
|
self.callback_handler.add_callback(callback)
|
||||||
|
|
||||||
|
def pop_callback(self, callback):
|
||||||
|
"""
|
||||||
|
Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.
|
||||||
|
|
||||||
|
If the callback is not found, returns :obj:`None` (and no error is raised).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
||||||
|
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
||||||
|
In the first case, will pop the first member of that class found in the list of callbacks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:class:`~transformer.TrainerCallback`: The callback removed, if found.
|
||||||
|
"""
|
||||||
|
return self.callback_handler.pop_callback(callback)
|
||||||
|
|
||||||
|
def remove_callback(self, callback):
|
||||||
|
"""
|
||||||
|
Remove a callback from the current list of :class:`~transformer.TrainerCallback`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
|
||||||
|
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
|
||||||
|
In the first case, will remove the first member of that class found in the list of callbacks.
|
||||||
|
"""
|
||||||
|
self.callback_handler.remove_callback(callback)
|
||||||
|
|
||||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||||
if not self.args.remove_unused_columns:
|
if not self.args.remove_unused_columns:
|
||||||
@@ -465,102 +476,12 @@ class Trainer:
|
|||||||
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup_wandb(self):
|
|
||||||
"""
|
|
||||||
Setup the optional Weights & Biases (`wandb`) integration.
|
|
||||||
|
|
||||||
One can subclass and override this method to customize the setup if needed. Find more information
|
|
||||||
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
|
|
||||||
|
|
||||||
Environment:
|
|
||||||
WANDB_WATCH:
|
|
||||||
(Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
|
|
||||||
or "all" to log gradients and parameters
|
|
||||||
WANDB_PROJECT:
|
|
||||||
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
|
|
||||||
WANDB_DISABLED:
|
|
||||||
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
|
|
||||||
"""
|
|
||||||
if hasattr(self, "_setup_wandb"):
|
|
||||||
warnings.warn(
|
|
||||||
"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
return self._setup_wandb()
|
|
||||||
|
|
||||||
if self.is_world_process_zero():
|
|
||||||
logger.info(
|
|
||||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
|
||||||
)
|
|
||||||
combined_dict = {**self.args.to_sanitized_dict()}
|
|
||||||
if isinstance(self.model, PreTrainedModel):
|
|
||||||
combined_dict = {**self.model.config.to_dict(), **combined_dict}
|
|
||||||
wandb.init(
|
|
||||||
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
|
||||||
)
|
|
||||||
# keep track of model topology and gradients, unsupported on TPU
|
|
||||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
|
||||||
wandb.watch(
|
|
||||||
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
|
|
||||||
)
|
|
||||||
|
|
||||||
def setup_comet(self):
|
|
||||||
"""
|
|
||||||
Setup the optional Comet.ml integration.
|
|
||||||
|
|
||||||
Environment:
|
|
||||||
COMET_MODE:
|
|
||||||
(Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
|
|
||||||
COMET_PROJECT_NAME:
|
|
||||||
(Optional): str - Comet.ml project name for experiments
|
|
||||||
COMET_OFFLINE_DIRECTORY:
|
|
||||||
(Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
|
|
||||||
|
|
||||||
For a number of configurable items in the environment,
|
|
||||||
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
|
|
||||||
"""
|
|
||||||
if self.is_world_master():
|
|
||||||
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
|
|
||||||
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
|
|
||||||
experiment = None
|
|
||||||
if comet_mode == "ONLINE":
|
|
||||||
experiment = comet_ml.Experiment(**args)
|
|
||||||
logger.info("Automatic Comet.ml online logging enabled")
|
|
||||||
elif comet_mode == "OFFLINE":
|
|
||||||
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
|
|
||||||
experiment = comet_ml.OfflineExperiment(**args)
|
|
||||||
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
|
|
||||||
if experiment is not None:
|
|
||||||
experiment._set_model_graph(self.model, framework="transformers")
|
|
||||||
experiment._log_parameters(self.args, prefix="args/", framework="transformers")
|
|
||||||
if isinstance(self.model, PreTrainedModel):
|
|
||||||
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
|
|
||||||
|
|
||||||
def num_examples(self, dataloader: DataLoader) -> int:
|
def num_examples(self, dataloader: DataLoader) -> int:
|
||||||
"""
|
"""
|
||||||
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
||||||
"""
|
"""
|
||||||
return len(dataloader.dataset)
|
return len(dataloader.dataset)
|
||||||
|
|
||||||
def _setup_loggers(self):
|
|
||||||
if self._loggers_initialized:
|
|
||||||
return
|
|
||||||
if is_wandb_available():
|
|
||||||
self.setup_wandb()
|
|
||||||
elif os.environ.get("WANDB_DISABLED") != "true":
|
|
||||||
logger.info(
|
|
||||||
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
|
|
||||||
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
|
|
||||||
)
|
|
||||||
if is_comet_available():
|
|
||||||
self.setup_comet()
|
|
||||||
elif os.environ.get("COMET_MODE") != "DISABLED":
|
|
||||||
logger.info(
|
|
||||||
"To use comet_ml logging, run `pip/conda install comet_ml` "
|
|
||||||
"see https://www.comet.ml/docs/python-sdk/huggingface/"
|
|
||||||
)
|
|
||||||
self._loggers_initialized = True
|
|
||||||
|
|
||||||
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||||
""" HP search setup code """
|
""" HP search setup code """
|
||||||
if self.hp_search_backend is None or trial is None:
|
if self.hp_search_backend is None or trial is None:
|
||||||
@@ -661,7 +582,7 @@ class Trainer:
|
|||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
# Moxed precision training with apex (torch < 1.6)
|
# Mixed precision training with apex (torch < 1.6)
|
||||||
model = self.model
|
model = self.model
|
||||||
if self.args.fp16 and _use_apex:
|
if self.args.fp16 and _use_apex:
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
@@ -687,10 +608,6 @@ class Trainer:
|
|||||||
# find_unused_parameters breaks checkpointing as per
|
# find_unused_parameters breaks checkpointing as per
|
||||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||||
|
|
||||||
if self.tb_writer is not None:
|
|
||||||
self.tb_writer.add_text("args", self.args.to_json_string())
|
|
||||||
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
|
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
||||||
@@ -723,17 +640,25 @@ class Trainer:
|
|||||||
logger.info(" Continuing training from global step %d", self.state.global_step)
|
logger.info(" Continuing training from global step %d", self.state.global_step)
|
||||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||||
|
|
||||||
|
# Update the references
|
||||||
|
self.callback_handler.model = self.model
|
||||||
|
self.callback_handler.optimizer = self.optimizer
|
||||||
|
self.callback_handler.lr_scheduler = self.lr_scheduler
|
||||||
|
self.callback_handler.train_dataloader = train_dataloader
|
||||||
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
|
||||||
# to set this after the load.
|
# to set this after the load.
|
||||||
self.state.max_steps = max_steps
|
self.state.max_steps = max_steps
|
||||||
self.state.num_train_epochs = num_train_epochs
|
self.state.num_train_epochs = num_train_epochs
|
||||||
|
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||||
|
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||||
|
|
||||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||||
|
self._logging_loss_scalar = 0
|
||||||
self._total_flos = self.state.total_flos
|
self._total_flos = self.state.total_flos
|
||||||
logging_loss_scalar = 0.0
|
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
|
||||||
train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
|
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
|
||||||
|
|
||||||
for epoch in range(epochs_trained, num_train_epochs):
|
for epoch in range(epochs_trained, num_train_epochs):
|
||||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
train_dataloader.sampler.set_epoch(epoch)
|
train_dataloader.sampler.set_epoch(epoch)
|
||||||
@@ -750,15 +675,18 @@ class Trainer:
|
|||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = None
|
self._past = None
|
||||||
|
|
||||||
epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
|
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
|
||||||
|
|
||||||
for step, inputs in enumerate(epoch_iterator):
|
for step, inputs in enumerate(epoch_iterator):
|
||||||
|
|
||||||
# Skip past any already trained steps if resuming training
|
# Skip past any already trained steps if resuming training
|
||||||
if steps_trained_in_current_epoch > 0:
|
if steps_trained_in_current_epoch > 0:
|
||||||
steps_trained_in_current_epoch -= 1
|
steps_trained_in_current_epoch -= 1
|
||||||
epoch_pbar.update(1)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if (step + 1) % self.args.gradient_accumulation_steps == 0:
|
||||||
|
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
|
||||||
|
|
||||||
tr_loss += self.training_step(model, inputs)
|
tr_loss += self.training_step(model, inputs)
|
||||||
self._total_flos += self.floating_point_ops(inputs)
|
self._total_flos += self.floating_point_ops(inputs)
|
||||||
|
|
||||||
@@ -787,50 +715,15 @@ class Trainer:
|
|||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
self.state.global_step += 1
|
self.state.global_step += 1
|
||||||
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
|
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
|
||||||
|
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
|
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||||
self.state.global_step == 1 and self.args.logging_first_step
|
|
||||||
):
|
|
||||||
logs: Dict[str, float] = {}
|
|
||||||
tr_loss_scalar = tr_loss.item()
|
|
||||||
logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
|
|
||||||
# backward compatibility for pytorch schedulers
|
|
||||||
logs["learning_rate"] = (
|
|
||||||
self.lr_scheduler.get_last_lr()[0]
|
|
||||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
|
||||||
else self.lr_scheduler.get_lr()[0]
|
|
||||||
)
|
|
||||||
logging_loss_scalar = tr_loss_scalar
|
|
||||||
|
|
||||||
self.log(logs)
|
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||||
|
|
||||||
if (
|
|
||||||
self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
|
||||||
and self.state.global_step % self.args.eval_steps == 0
|
|
||||||
):
|
|
||||||
metrics = self.evaluate()
|
|
||||||
self._report_to_hp_search(trial, epoch, metrics)
|
|
||||||
if self.args.load_best_model_at_end:
|
|
||||||
self._save_training(model, trial, metrics=metrics)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not self.args.load_best_model_at_end
|
|
||||||
and self.args.save_steps > 0
|
|
||||||
and self.state.global_step % self.args.save_steps == 0
|
|
||||||
):
|
|
||||||
self._save_training(model, trial)
|
|
||||||
|
|
||||||
epoch_pbar.update(1)
|
|
||||||
if self.state.global_step >= max_steps:
|
|
||||||
break
|
break
|
||||||
epoch_pbar.close()
|
|
||||||
train_pbar.update(1)
|
|
||||||
|
|
||||||
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
|
||||||
metrics = self.evaluate()
|
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
|
||||||
self._report_to_hp_search(trial, epoch, metrics)
|
|
||||||
if self.args.load_best_model_at_end:
|
|
||||||
self._save_training(model, trial, metrics=metrics)
|
|
||||||
|
|
||||||
if self.args.tpu_metrics_debug or self.args.debug:
|
if self.args.tpu_metrics_debug or self.args.debug:
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
@@ -841,12 +734,9 @@ class Trainer:
|
|||||||
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
||||||
"configured. Check your training configuration if this is unexpected."
|
"configured. Check your training configuration if this is unexpected."
|
||||||
)
|
)
|
||||||
if self.state.global_step >= max_steps:
|
if self.control.should_training_stop:
|
||||||
break
|
break
|
||||||
|
|
||||||
train_pbar.close()
|
|
||||||
if self.tb_writer:
|
|
||||||
self.tb_writer.close()
|
|
||||||
if self.args.past_index and hasattr(self, "_past"):
|
if self.args.past_index and hasattr(self, "_past"):
|
||||||
# Clean the state at the end of training
|
# Clean the state at the end of training
|
||||||
delattr(self, "_past")
|
delattr(self, "_past")
|
||||||
@@ -863,9 +753,36 @@ class Trainer:
|
|||||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||||
|
|
||||||
def _save_training(self, model, trial, metrics=None):
|
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
|
||||||
|
if self.control.should_log:
|
||||||
|
logs: Dict[str, float] = {}
|
||||||
|
tr_loss_scalar = tr_loss.item()
|
||||||
|
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
|
||||||
|
# backward compatibility for pytorch schedulers
|
||||||
|
logs["learning_rate"] = (
|
||||||
|
self.lr_scheduler.get_last_lr()[0]
|
||||||
|
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||||
|
else self.lr_scheduler.get_lr()[0]
|
||||||
|
)
|
||||||
|
self._logging_loss_scalar = tr_loss_scalar
|
||||||
|
|
||||||
|
self.log(logs)
|
||||||
|
|
||||||
|
metrics = None
|
||||||
|
if self.control.should_evaluate:
|
||||||
|
metrics = self.evaluate()
|
||||||
|
self._report_to_hp_search(trial, epoch, metrics)
|
||||||
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||||
|
|
||||||
|
if self.control.should_save:
|
||||||
|
self._save_checkpoint(model, trial, metrics=metrics)
|
||||||
|
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||||
|
|
||||||
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
# In all cases (even distributed/parallel), self.model is always a reference
|
# In all cases (even distributed/parallel), self.model is always a reference
|
||||||
# to the model we want to save.
|
# to the model we want to save.
|
||||||
if hasattr(model, "module"):
|
if hasattr(model, "module"):
|
||||||
@@ -896,7 +813,7 @@ class Trainer:
|
|||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
# Determine the new best metric / best model checkpoint
|
# Determine the new best metric / best model checkpoint
|
||||||
if metrics is not None:
|
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||||
metric_to_check = self.args.metric_for_best_model
|
metric_to_check = self.args.metric_for_best_model
|
||||||
if not metric_to_check.startswith("eval_"):
|
if not metric_to_check.startswith("eval_"):
|
||||||
metric_to_check = f"eval_{metric_to_check}"
|
metric_to_check = f"eval_{metric_to_check}"
|
||||||
@@ -998,7 +915,7 @@ class Trainer:
|
|||||||
self.hp_search_backend = None
|
self.hp_search_backend = None
|
||||||
return best_run
|
return best_run
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
|
def log(self, logs: Dict[str, float]) -> None:
|
||||||
"""
|
"""
|
||||||
Log :obj:`logs` on the various objects watching training.
|
Log :obj:`logs` on the various objects watching training.
|
||||||
|
|
||||||
@@ -1007,55 +924,22 @@ class Trainer:
|
|||||||
Args:
|
Args:
|
||||||
logs (:obj:`Dict[str, float]`):
|
logs (:obj:`Dict[str, float]`):
|
||||||
The values to log.
|
The values to log.
|
||||||
iterator (:obj:`tqdm`, `optional`):
|
|
||||||
A potential tqdm progress bar to write the logs on.
|
|
||||||
"""
|
"""
|
||||||
# Set up loggers like W&B or Comet ML
|
|
||||||
self._setup_loggers()
|
|
||||||
|
|
||||||
if hasattr(self, "_log"):
|
if hasattr(self, "_log"):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
|
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
return self._log(logs, iterator=iterator)
|
return self._log(logs)
|
||||||
|
|
||||||
if self.state.epoch is not None:
|
if self.state.epoch is not None:
|
||||||
logs["epoch"] = self.state.epoch
|
logs["epoch"] = self.state.epoch
|
||||||
if self._total_flos is not None:
|
if self._total_flos is not None:
|
||||||
self.store_flos()
|
self.store_flos()
|
||||||
logs["total_flos"] = self.state.total_flos
|
logs["total_flos"] = self.state.total_flos
|
||||||
if self.tb_writer:
|
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||||
for k, v in logs.items():
|
|
||||||
if isinstance(v, (int, float)):
|
|
||||||
self.tb_writer.add_scalar(k, v, self.state.global_step)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Trainer is attempting to log a value of "
|
|
||||||
'"%s" of type %s for key "%s" as a scalar. '
|
|
||||||
"This invocation of Tensorboard's writer.add_scalar() "
|
|
||||||
"is incorrect so we dropped this attribute.",
|
|
||||||
v,
|
|
||||||
type(v),
|
|
||||||
k,
|
|
||||||
)
|
|
||||||
self.tb_writer.flush()
|
|
||||||
if is_wandb_available():
|
|
||||||
if self.is_world_process_zero():
|
|
||||||
wandb.log(logs, step=self.state.global_step)
|
|
||||||
if is_comet_available():
|
|
||||||
if self.is_world_process_zero():
|
|
||||||
experiment = comet_ml.config.get_global_experiment()
|
|
||||||
if experiment is not None:
|
|
||||||
experiment._log_metrics(
|
|
||||||
logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
|
|
||||||
)
|
|
||||||
output = {**logs, **{"step": self.state.global_step}}
|
output = {**logs, **{"step": self.state.global_step}}
|
||||||
self.state.log_history.append(output)
|
self.state.log_history.append(output)
|
||||||
if iterator is not None:
|
|
||||||
iterator.write(output)
|
|
||||||
else:
|
|
||||||
print(output)
|
|
||||||
|
|
||||||
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -1372,8 +1256,9 @@ class Trainer:
|
|||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = None
|
self._past = None
|
||||||
|
|
||||||
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
self.callback_handler.eval_dataloader = dataloader
|
||||||
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
|
||||||
|
for inputs in dataloader:
|
||||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
@@ -1382,6 +1267,7 @@ class Trainer:
|
|||||||
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
|
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
|
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
|
||||||
|
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
||||||
|
|
||||||
if self.args.past_index and hasattr(self, "_past"):
|
if self.args.past_index and hasattr(self, "_past"):
|
||||||
# Clean the state at the end of the evaluation loop
|
# Clean the state at the end of the evaluation loop
|
||||||
|
|||||||
468
src/transformers/trainer_callback.py
Normal file
468
src/transformers/trainer_callback.py
Normal file
@@ -0,0 +1,468 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020-present the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Callbacks to use with the Trainer class and customize the training loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from .trainer_utils import EvaluationStrategy
|
||||||
|
from .training_args import TrainingArguments
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainerState:
|
||||||
|
"""
|
||||||
|
A class containing the :class:`~transformers.Trainer` inner state that will be saved along the model and optimizer
|
||||||
|
when checkpointing and passed to the :class:`~transformers.TrainerCallback`.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
|
||||||
|
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
|
||||||
|
then one update step requires going throuch `n` batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch (:obj:`float`, `optional`):
|
||||||
|
Only set during training, will represent the epoch the training is at (the decimal part being the
|
||||||
|
percentage of the current epoch completed).
|
||||||
|
global_step (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
During training, represents the number of update steps completed.
|
||||||
|
max_steps (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
The number of update steps to do during the current training.
|
||||||
|
total_flos (:obj:`int`, `optional`, defaults to 0):
|
||||||
|
The total number of floating operations done by the model since the beginning of training.
|
||||||
|
log_history (:obj:`List[Dict[str, float]]`, `optional`):
|
||||||
|
The list of logs done since the beginning of training.
|
||||||
|
best_metric (:obj:`float`, `optional`):
|
||||||
|
When tracking the best model, the value of the best metric encountered so far.
|
||||||
|
best_model_checkpoint (:obj:`str`, `optional`):
|
||||||
|
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
|
||||||
|
far.
|
||||||
|
is_local_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
||||||
|
several machines) main process.
|
||||||
|
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not this process is the global main process (when training in a distributed fashion on
|
||||||
|
several machines, this is only going to be :obj:`True` for one process).
|
||||||
|
"""
|
||||||
|
|
||||||
|
epoch: Optional[float] = None
|
||||||
|
global_step: int = 0
|
||||||
|
max_steps: int = 0
|
||||||
|
num_train_epochs: int = 0
|
||||||
|
total_flos: int = 0
|
||||||
|
log_history: List[Dict[str, float]] = None
|
||||||
|
best_metric: Optional[float] = None
|
||||||
|
best_model_checkpoint: Optional[str] = None
|
||||||
|
is_local_process_zero: bool = True
|
||||||
|
is_world_process_zero: bool = True
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.log_history is None:
|
||||||
|
self.log_history = []
|
||||||
|
|
||||||
|
def save_to_json(self, json_path: str):
|
||||||
|
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
|
||||||
|
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(json_string)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_from_json(cls, json_path: str):
|
||||||
|
""" Create an instance from the content of :obj:`json_path`."""
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
|
text = f.read()
|
||||||
|
return cls(**json.loads(text))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainerControl:
|
||||||
|
"""
|
||||||
|
A class that handles the :class:`~transformers.Trainer` control flow. This class is used by the
|
||||||
|
:class:`~transformers.TrainerCallback` to activate some switches in the training loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
should_training_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the training should be interrupted.
|
||||||
|
|
||||||
|
If :obj:`True`, this variable will not be set back to :obj:`False`. The training will just stop.
|
||||||
|
should_epoch_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the current epoch should be interrupted.
|
||||||
|
|
||||||
|
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next epoch.
|
||||||
|
should_save (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the model should be saved at this step.
|
||||||
|
|
||||||
|
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
|
||||||
|
should_evaluate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the model should be evaluated at this step.
|
||||||
|
|
||||||
|
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
|
||||||
|
should_log (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the logs should be reported at this step.
|
||||||
|
|
||||||
|
If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
should_training_stop: bool = False
|
||||||
|
should_epoch_stop: bool = False
|
||||||
|
should_save: bool = False
|
||||||
|
should_evaluate: bool = False
|
||||||
|
should_log: bool = False
|
||||||
|
|
||||||
|
def _new_training(self):
|
||||||
|
""" Internal method that resets the variable for a new training. """
|
||||||
|
self.should_training_stop = False
|
||||||
|
|
||||||
|
def _new_epoch(self):
|
||||||
|
""" Internal method that resets the variable for a new epoch. """
|
||||||
|
self.should_epoch_stop = False
|
||||||
|
|
||||||
|
def _new_step(self):
|
||||||
|
""" Internal method that resets the variable for a new step. """
|
||||||
|
self.should_save_model = False
|
||||||
|
self.should_evaluate = False
|
||||||
|
self.should_log = False
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerCallback:
|
||||||
|
"""
|
||||||
|
A class for objects that will inspect the state of the training loop at some events and take some decisions. At
|
||||||
|
each of those events the following arguments are available:
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (:class:`~transformers.TrainingArguments`):
|
||||||
|
The training arguments used to instantiate the :class:`~transformers.Trainer`.
|
||||||
|
state (:class:`~transformers.TrainerState`):
|
||||||
|
The current state of the :class:`~transformers.Trainer`.
|
||||||
|
control (:class:`~transformers.TrainerControl`):
|
||||||
|
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
|
||||||
|
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
|
||||||
|
The model being trained.
|
||||||
|
optimizer (:obj:`torch.optim.Optimizer`):
|
||||||
|
The optimizer used for the training steps.
|
||||||
|
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
|
||||||
|
The scheduler used for setting the learning rate.
|
||||||
|
train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
|
||||||
|
The current dataloader used for training.
|
||||||
|
eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
|
||||||
|
The current dataloader used for training.
|
||||||
|
metrics (:obj:`Dict[str, float]`):
|
||||||
|
The metrics computed by the last evaluation phase.
|
||||||
|
|
||||||
|
Those are only accessible in the event :obj:`on_evaluate`.
|
||||||
|
logs (:obj:`Dict[str, float]`):
|
||||||
|
The values to log.
|
||||||
|
|
||||||
|
Those are only accessible in the event :obj:`on_log`.
|
||||||
|
|
||||||
|
The :obj:`control` object is the only one that can be changed by the callback, in which case the event that changes
|
||||||
|
it should return the modified version.
|
||||||
|
|
||||||
|
The argument :obj:`args`, :obj:`state` and :obj:`control` are positionals for all events, all the others are
|
||||||
|
grouped in :obj:`kwargs`. You can unpack the ones you need in the signature of the event using them. As an example,
|
||||||
|
see the code of the simple :class:`~transformer.PrinterCallback`.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
class PrinterCallback(TrainerCallback):
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
_ = logs.pop("total_flos", None)
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
print(logs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called at the end of the initialization of the :class:`~transformers.Trainer`.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called at the beginning of training.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called at the end of training.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called at the beginning of an epoch.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called at the end of an epoch.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
|
||||||
|
several inputs.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called at the end of a training step. If using gradient accumulation, one training step might take
|
||||||
|
several inputs.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called after an evaluation phase.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called after a checkpoint save.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called after logging the last logs.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
"""
|
||||||
|
Event called after a prediction step.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackHandler(TrainerCallback):
|
||||||
|
""" Internal class that just calls the list of callbacks in order. """
|
||||||
|
|
||||||
|
def __init__(self, callbacks, model, optimizer, lr_scheduler):
|
||||||
|
self.callbacks = []
|
||||||
|
for cb in callbacks:
|
||||||
|
self.add_callback(cb)
|
||||||
|
self.model = model
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
self.train_dataloader = None
|
||||||
|
self.eval_dataloader = None
|
||||||
|
|
||||||
|
if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
|
||||||
|
logger.warn(
|
||||||
|
"The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
|
||||||
|
+ "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
|
||||||
|
+ "callbacks is\n:"
|
||||||
|
+ self.callback_list
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_callback(self, callback):
|
||||||
|
cb = callback() if isinstance(callback, type) else callback
|
||||||
|
cb_class = callback if isinstance(callback, type) else callback.__class__
|
||||||
|
if cb_class in [c.__class__ for c in self.callbacks]:
|
||||||
|
logger.warn(
|
||||||
|
f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
|
||||||
|
+ "list of callbacks is\n:"
|
||||||
|
+ self.callback_list
|
||||||
|
)
|
||||||
|
self.callbacks.append(cb)
|
||||||
|
|
||||||
|
def pop_callback(self, callback):
|
||||||
|
if isinstance(callback, type):
|
||||||
|
for cb in self.callbacks:
|
||||||
|
if isinstance(cb, callback):
|
||||||
|
self.callbacks.remove(cb)
|
||||||
|
return cb
|
||||||
|
else:
|
||||||
|
for cb in self.callbacks:
|
||||||
|
if cb == callback:
|
||||||
|
self.callbacks.remove(cb)
|
||||||
|
return cb
|
||||||
|
|
||||||
|
def remove_callback(self, callback):
|
||||||
|
if isinstance(callback, type):
|
||||||
|
for cb in self.callbacks:
|
||||||
|
if isinstance(cb, callback):
|
||||||
|
self.callbacks.remove(cb)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.callbacks.remove(callback)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def callback_list(self):
|
||||||
|
return "\n".join(self.callbacks)
|
||||||
|
|
||||||
|
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
return self.call_event("on_init_end", args, state, control)
|
||||||
|
|
||||||
|
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
control.should_training_stop = False
|
||||||
|
return self.call_event("on_train_begin", args, state, control)
|
||||||
|
|
||||||
|
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
return self.call_event("on_train_end", args, state, control)
|
||||||
|
|
||||||
|
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
control.should_epoch_stop = False
|
||||||
|
return self.call_event("on_epoch_begin", args, state, control)
|
||||||
|
|
||||||
|
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
return self.call_event("on_epoch_end", args, state, control)
|
||||||
|
|
||||||
|
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
control.should_log = False
|
||||||
|
control.should_evaluate = False
|
||||||
|
control.should_save = False
|
||||||
|
return self.call_event("on_step_begin", args, state, control)
|
||||||
|
|
||||||
|
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
return self.call_event("on_step_end", args, state, control)
|
||||||
|
|
||||||
|
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
|
||||||
|
control.should_evaluate = False
|
||||||
|
return self.call_event("on_evaluate", args, state, control, metrics=metrics)
|
||||||
|
|
||||||
|
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
control.should_save = False
|
||||||
|
return self.call_event("on_save", args, state, control)
|
||||||
|
|
||||||
|
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
|
||||||
|
control.should_log = False
|
||||||
|
return self.call_event("on_log", args, state, control, logs=logs)
|
||||||
|
|
||||||
|
def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
|
||||||
|
return self.call_event("on_prediction_step", args, state, control)
|
||||||
|
|
||||||
|
def call_event(self, event, args, state, control, **kwargs):
|
||||||
|
for callback in self.callbacks:
|
||||||
|
result = getattr(callback, event)(
|
||||||
|
args,
|
||||||
|
state,
|
||||||
|
control,
|
||||||
|
model=self.model,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
lr_scheduler=self.lr_scheduler,
|
||||||
|
train_dataloader=self.train_dataloader,
|
||||||
|
eval_dataloader=self.eval_dataloader,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# A Callback can skip the return of `control` if it doesn't change it.
|
||||||
|
if result is not None:
|
||||||
|
control = result
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultFlowCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation
|
||||||
|
and checkpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
# Log
|
||||||
|
if state.global_step == 1 and args.logging_first_step:
|
||||||
|
control.should_log = True
|
||||||
|
if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
|
||||||
|
control.should_log = True
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||||
|
control.should_evaluate = True
|
||||||
|
if args.load_best_model_at_end:
|
||||||
|
control.should_save = True
|
||||||
|
|
||||||
|
# Save
|
||||||
|
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
|
||||||
|
control.should_save = True
|
||||||
|
|
||||||
|
# End training
|
||||||
|
if state.global_step >= state.max_steps:
|
||||||
|
control.should_training_stop = True
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||||
|
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||||
|
control.should_evaluate = True
|
||||||
|
if args.load_best_model_at_end:
|
||||||
|
control.should_save = True
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.training_bar = None
|
||||||
|
self.prediction_bar = None
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
self.training_bar = tqdm(total=state.max_steps)
|
||||||
|
|
||||||
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
self.training_bar.update(1)
|
||||||
|
|
||||||
|
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
if self.prediction_bar is None:
|
||||||
|
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
|
||||||
|
self.prediction_bar.update(1)
|
||||||
|
|
||||||
|
def on_evaluate(self, args, state, control, **kwargs):
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
self.prediction_bar.close()
|
||||||
|
self.prediction_bar = None
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
if state.is_local_process_zero and self.training_bar is not None:
|
||||||
|
_ = logs.pop("total_flos", None)
|
||||||
|
self.training_bar.write(str(logs))
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
self.training_bar.close()
|
||||||
|
self.training_bar = None
|
||||||
|
|
||||||
|
|
||||||
|
class PrinterCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
A bare :class:`~transformers.TrainerCallback` that just prints the logs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
_ = logs.pop("total_flos", None)
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
print(logs)
|
||||||
179
src/transformers/trainer_pt_utils.py
Normal file
179
src/transformers/trainer_pt_utils.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020-present the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Torch utilities for the Trainer class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||||
|
|
||||||
|
from .file_utils import is_torch_tpu_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||||
|
|
||||||
|
|
||||||
|
def nested_concat(tensors, new_tensors, dim=0):
|
||||||
|
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
|
||||||
|
assert type(tensors) == type(
|
||||||
|
new_tensors
|
||||||
|
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||||
|
if isinstance(tensors, (list, tuple)):
|
||||||
|
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
|
||||||
|
return torch.cat((tensors, new_tensors), dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def nested_numpify(tensors):
|
||||||
|
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||||
|
if isinstance(tensors, (list, tuple)):
|
||||||
|
return type(tensors)(nested_numpify(t) for t in tensors)
|
||||||
|
return tensors.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def nested_detach(tensors):
|
||||||
|
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
||||||
|
if isinstance(tensors, (list, tuple)):
|
||||||
|
return type(tensors)(nested_detach(t) for t in tensors)
|
||||||
|
return tensors.detach()
|
||||||
|
|
||||||
|
|
||||||
|
def nested_xla_mesh_reduce(tensors, name):
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
if isinstance(tensors, (list, tuple)):
|
||||||
|
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
||||||
|
return xm.mesh_reduce(name, tensors, torch.cat)
|
||||||
|
else:
|
||||||
|
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
||||||
|
|
||||||
|
|
||||||
|
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
|
||||||
|
try:
|
||||||
|
if isinstance(tensor, (tuple, list)):
|
||||||
|
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||||
|
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||||
|
torch.distributed.all_gather(output_tensors, tensor)
|
||||||
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
|
|
||||||
|
# truncate the dummy elements added by SequentialDistributedSampler
|
||||||
|
if num_total_examples is not None:
|
||||||
|
concat = concat[:num_total_examples]
|
||||||
|
return concat
|
||||||
|
except AssertionError:
|
||||||
|
raise AssertionError("Not currently using distributed training")
|
||||||
|
|
||||||
|
|
||||||
|
def distributed_broadcast_scalars(
|
||||||
|
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
try:
|
||||||
|
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||||
|
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||||
|
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||||
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
|
|
||||||
|
# truncate the dummy elements added by SequentialDistributedSampler
|
||||||
|
if num_total_examples is not None:
|
||||||
|
concat = concat[:num_total_examples]
|
||||||
|
return concat
|
||||||
|
except AssertionError:
|
||||||
|
raise AssertionError("Not currently using distributed training")
|
||||||
|
|
||||||
|
|
||||||
|
def reissue_pt_warnings(caught_warnings):
|
||||||
|
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
||||||
|
if len(caught_warnings) > 1:
|
||||||
|
for w in caught_warnings:
|
||||||
|
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
||||||
|
warnings.warn(w.message, w.category)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def torch_distributed_zero_first(local_rank: int):
|
||||||
|
"""
|
||||||
|
Decorator to make all processes in distributed training wait for each local_master to do something.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_rank (:obj:`int`): The rank of the local process.
|
||||||
|
"""
|
||||||
|
if local_rank not in [-1, 0]:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
yield
|
||||||
|
if local_rank == 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialDistributedSampler(Sampler):
|
||||||
|
"""
|
||||||
|
Distributed Sampler that subsamples indicies sequentially,
|
||||||
|
making it easier to collate all results at the end.
|
||||||
|
|
||||||
|
Even though we only use this sampler for eval and predict (no training),
|
||||||
|
which means that the model params won't have to be synced (i.e. will not hang
|
||||||
|
for synchronization even if varied number of forward passes), we still add extra
|
||||||
|
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
|
||||||
|
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||||
|
if num_replicas is None:
|
||||||
|
if not torch.distributed.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
num_replicas = torch.distributed.get_world_size()
|
||||||
|
if rank is None:
|
||||||
|
if not torch.distributed.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
rank = torch.distributed.get_rank()
|
||||||
|
self.dataset = dataset
|
||||||
|
self.num_replicas = num_replicas
|
||||||
|
self.rank = rank
|
||||||
|
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||||
|
self.total_size = self.num_samples * self.num_replicas
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
indices = list(range(len(self.dataset)))
|
||||||
|
|
||||||
|
# add extra samples to make it evenly divisible
|
||||||
|
indices += indices[: (self.total_size - len(indices))]
|
||||||
|
assert (
|
||||||
|
len(indices) == self.total_size
|
||||||
|
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
|
||||||
|
|
||||||
|
# subsample
|
||||||
|
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
||||||
|
assert (
|
||||||
|
len(indices) == self.num_samples
|
||||||
|
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
||||||
|
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
||||||
|
|
||||||
|
|
||||||
|
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
|
||||||
|
if xm.xrt_world_size() <= 1:
|
||||||
|
return RandomSampler(dataset)
|
||||||
|
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||||
@@ -1,19 +1,30 @@
|
|||||||
import dataclasses
|
# coding=utf-8
|
||||||
import json
|
# Copyright 2020-present the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
|
||||||
|
"""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
|
from .file_utils import is_tf_available, is_torch_available
|
||||||
from .tokenization_utils_base import ExplicitEnum
|
from .tokenization_utils_base import ExplicitEnum
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int):
|
def set_seed(seed: int):
|
||||||
"""
|
"""
|
||||||
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
|
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
|
||||||
@@ -139,144 +150,3 @@ default_hp_space = {
|
|||||||
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
||||||
HPSearchBackend.RAY: default_hp_space_ray,
|
HPSearchBackend.RAY: default_hp_space_ray,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def nested_concat(tensors, new_tensors, dim=0):
|
|
||||||
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
|
|
||||||
if is_torch_available():
|
|
||||||
assert type(tensors) == type(
|
|
||||||
new_tensors
|
|
||||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
|
||||||
if isinstance(tensors, (list, tuple)):
|
|
||||||
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
|
|
||||||
return torch.cat((tensors, new_tensors), dim=dim)
|
|
||||||
else:
|
|
||||||
raise ImportError("Torch must be installed to use `nested_concat`")
|
|
||||||
|
|
||||||
|
|
||||||
def nested_deatch(tensors):
|
|
||||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
|
||||||
if isinstance(tensors, (list, tuple)):
|
|
||||||
return type(tensors)(nested_detach(t) for t in tensors)
|
|
||||||
return tensors.detach()
|
|
||||||
|
|
||||||
|
|
||||||
def nested_numpify(tensors):
|
|
||||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
|
||||||
if isinstance(tensors, (list, tuple)):
|
|
||||||
return type(tensors)(nested_numpify(t) for t in tensors)
|
|
||||||
return tensors.cpu().numpy()
|
|
||||||
|
|
||||||
|
|
||||||
def nested_detach(tensors):
|
|
||||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
|
||||||
if isinstance(tensors, (list, tuple)):
|
|
||||||
return type(tensors)(nested_detach(t) for t in tensors)
|
|
||||||
return tensors.detach()
|
|
||||||
|
|
||||||
|
|
||||||
def nested_xla_mesh_reduce(tensors, name):
|
|
||||||
if is_torch_tpu_available():
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
if isinstance(tensors, (list, tuple)):
|
|
||||||
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
|
||||||
return xm.mesh_reduce(name, tensors, torch.cat)
|
|
||||||
else:
|
|
||||||
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
|
||||||
|
|
||||||
|
|
||||||
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
|
|
||||||
if is_torch_available():
|
|
||||||
try:
|
|
||||||
if isinstance(tensor, (tuple, list)):
|
|
||||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
|
||||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
|
||||||
torch.distributed.all_gather(output_tensors, tensor)
|
|
||||||
concat = torch.cat(output_tensors, dim=0)
|
|
||||||
|
|
||||||
# truncate the dummy elements added by SequentialDistributedSampler
|
|
||||||
if num_total_examples is not None:
|
|
||||||
concat = concat[:num_total_examples]
|
|
||||||
return concat
|
|
||||||
except AssertionError:
|
|
||||||
raise AssertionError("Not currently using distributed training")
|
|
||||||
else:
|
|
||||||
raise ImportError("Torch must be installed to use `distributed_concat`")
|
|
||||||
|
|
||||||
|
|
||||||
def distributed_broadcast_scalars(
|
|
||||||
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
|
|
||||||
) -> "torch.Tensor":
|
|
||||||
if is_torch_available():
|
|
||||||
try:
|
|
||||||
tensorized_scalar = torch.tensor(scalars).cuda()
|
|
||||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
|
||||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
|
||||||
concat = torch.cat(output_tensors, dim=0)
|
|
||||||
|
|
||||||
# truncate the dummy elements added by SequentialDistributedSampler
|
|
||||||
if num_total_examples is not None:
|
|
||||||
concat = concat[:num_total_examples]
|
|
||||||
return concat
|
|
||||||
except AssertionError:
|
|
||||||
raise AssertionError("Not currently using distributed training")
|
|
||||||
else:
|
|
||||||
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainerState:
|
|
||||||
"""
|
|
||||||
A class containing the `Trainer` inner state that will be saved along the model and optimizer.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
In all this class, one step is to be understood as one update step. When using gradient accumulation, one
|
|
||||||
update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
|
|
||||||
then one update step requires going throuch `n` batches.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
epoch (:obj:`float`, `optional`):
|
|
||||||
Only set during training, will represent the epoch the training is at (the decimal part being the
|
|
||||||
percentage of the current epoch completed).
|
|
||||||
global_step (:obj:`int`, `optional`, defaults to 0):
|
|
||||||
During training, represents the number of update steps completed.
|
|
||||||
max_steps (:obj:`int`, `optional`, defaults to 0):
|
|
||||||
The number of update steps to do during the current training.
|
|
||||||
total_flos (:obj:`int`, `optional`, defaults to 0):
|
|
||||||
The total number of floating operations done by the model since the beginning of training.
|
|
||||||
log_history (:obj:`List[Dict[str, float]]`, `optional`):
|
|
||||||
The list of logs done since the beginning of training.
|
|
||||||
best_metric (:obj:`float`, `optional`):
|
|
||||||
When tracking the best model, the value of the best metric encountered so far.
|
|
||||||
best_model_checkpoint (:obj:`str`, `optional`):
|
|
||||||
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
|
|
||||||
far.
|
|
||||||
"""
|
|
||||||
|
|
||||||
epoch: Optional[float] = None
|
|
||||||
global_step: int = 0
|
|
||||||
max_steps: int = 0
|
|
||||||
num_train_epochs: int = 0
|
|
||||||
total_flos: int = 0
|
|
||||||
log_history: List[Dict[str, float]] = None
|
|
||||||
best_metric: Optional[float] = None
|
|
||||||
best_model_checkpoint: Optional[str] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.log_history is None:
|
|
||||||
self.log_history = []
|
|
||||||
|
|
||||||
def save_to_json(self, json_path: str):
|
|
||||||
""" Save the content of this instance in JSON format inside :obj:`json_path`."""
|
|
||||||
json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(json_string)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_from_json(cls, json_path: str):
|
|
||||||
""" Create an instance from the content of :obj:`json_path`."""
|
|
||||||
with open(json_path, "r", encoding="utf-8") as f:
|
|
||||||
text = f.read()
|
|
||||||
return cls(**json.loads(text))
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class TrainingArguments:
|
|||||||
:obj:`"no"`.
|
:obj:`"no"`.
|
||||||
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to run predictions on the test set or not.
|
Whether to run predictions on the test set or not.
|
||||||
evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||||
The evaluation strategy to adopt during training. Possible values are:
|
The evaluation strategy to adopt during training. Possible values are:
|
||||||
|
|
||||||
* :obj:`"no"`: No evaluation is done during training.
|
* :obj:`"no"`: No evaluation is done during training.
|
||||||
|
|||||||
@@ -1869,19 +1869,10 @@ class MarianTokenizer:
|
|||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class EvalPrediction:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_pytorch(self)
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
def set_seed(*args, **kwargs):
|
|
||||||
requires_pytorch(set_seed)
|
|
||||||
|
|
||||||
|
|
||||||
def torch_distributed_zero_first(*args, **kwargs):
|
def torch_distributed_zero_first(*args, **kwargs):
|
||||||
requires_pytorch(torch_distributed_zero_first)
|
requires_pytorch(torch_distributed_zero_first)
|
||||||
|
|||||||
214
tests/test_trainer_callback.py
Normal file
214
tests/test_trainer_callback.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
DefaultFlowCallback,
|
||||||
|
EvaluationStrategy,
|
||||||
|
PrinterCallback,
|
||||||
|
ProgressCallback,
|
||||||
|
Trainer,
|
||||||
|
TrainerCallback,
|
||||||
|
TrainingArguments,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.trainer import DEFAULT_CALLBACKS
|
||||||
|
|
||||||
|
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainerCallback(TrainerCallback):
|
||||||
|
"A callback that registers the events that goes through."
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.events = []
|
||||||
|
|
||||||
|
def on_init_end(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_init_end")
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_train_begin")
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_train_end")
|
||||||
|
|
||||||
|
def on_epoch_begin(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_epoch_begin")
|
||||||
|
|
||||||
|
def on_epoch_end(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_epoch_end")
|
||||||
|
|
||||||
|
def on_step_begin(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_step_begin")
|
||||||
|
|
||||||
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_step_end")
|
||||||
|
|
||||||
|
def on_evaluate(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_evaluate")
|
||||||
|
|
||||||
|
def on_save(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_save")
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_log")
|
||||||
|
|
||||||
|
def on_prediction_step(self, args, state, control, **kwargs):
|
||||||
|
self.events.append("on_prediction_step")
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class TrainerCallbackTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.output_dir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.output_dir)
|
||||||
|
|
||||||
|
def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
|
||||||
|
# disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure
|
||||||
|
# its set to False since the tests later on depend on its value.
|
||||||
|
train_dataset = RegressionDataset(length=train_len)
|
||||||
|
eval_dataset = RegressionDataset(length=eval_len)
|
||||||
|
config = RegressionModelConfig(a=a, b=b)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
|
||||||
|
args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, **kwargs)
|
||||||
|
return Trainer(
|
||||||
|
model,
|
||||||
|
args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_callbacks_equality(self, cbs1, cbs2):
|
||||||
|
self.assertEqual(len(cbs1), len(cbs2))
|
||||||
|
|
||||||
|
# Order doesn't matter
|
||||||
|
cbs1 = list(sorted(cbs1, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
|
||||||
|
cbs2 = list(sorted(cbs2, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
|
||||||
|
|
||||||
|
for cb1, cb2 in zip(cbs1, cbs2):
|
||||||
|
if isinstance(cb1, type) and isinstance(cb2, type):
|
||||||
|
self.assertEqual(cb1, cb2)
|
||||||
|
elif isinstance(cb1, type) and not isinstance(cb2, type):
|
||||||
|
self.assertEqual(cb1, cb2.__class__)
|
||||||
|
elif not isinstance(cb1, type) and isinstance(cb2, type):
|
||||||
|
self.assertEqual(cb1.__class__, cb2)
|
||||||
|
else:
|
||||||
|
self.assertEqual(cb1, cb2)
|
||||||
|
|
||||||
|
def get_expected_events(self, trainer):
|
||||||
|
expected_events = ["on_init_end", "on_train_begin"]
|
||||||
|
step = 0
|
||||||
|
train_dl_len = len(trainer.get_eval_dataloader())
|
||||||
|
evaluation_events = ["on_prediction_step"] * len(trainer.get_eval_dataloader()) + ["on_log", "on_evaluate"]
|
||||||
|
for _ in range(trainer.state.num_train_epochs):
|
||||||
|
expected_events.append("on_epoch_begin")
|
||||||
|
for _ in range(train_dl_len):
|
||||||
|
step += 1
|
||||||
|
expected_events += ["on_step_begin", "on_step_end"]
|
||||||
|
if step % trainer.args.logging_steps == 0:
|
||||||
|
expected_events.append("on_log")
|
||||||
|
if (
|
||||||
|
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||||
|
and step % trainer.args.eval_steps == 0
|
||||||
|
):
|
||||||
|
expected_events += evaluation_events.copy()
|
||||||
|
if step % trainer.args.save_steps == 0:
|
||||||
|
expected_events.append("on_save")
|
||||||
|
expected_events.append("on_epoch_end")
|
||||||
|
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||||
|
expected_events += evaluation_events.copy()
|
||||||
|
expected_events.append("on_train_end")
|
||||||
|
return expected_events
|
||||||
|
|
||||||
|
def test_init_callback(self):
|
||||||
|
trainer = self.get_trainer()
|
||||||
|
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
# Callbacks passed at init are added to the default callbacks
|
||||||
|
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
||||||
|
expected_callbacks.append(TestTrainerCallback)
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
# TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
|
||||||
|
trainer = self.get_trainer(disable_tqdm=True)
|
||||||
|
expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
def test_add_remove_callback(self):
|
||||||
|
expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
|
||||||
|
trainer = self.get_trainer()
|
||||||
|
|
||||||
|
# We can add, pop, or remove by class name
|
||||||
|
trainer.remove_callback(DefaultFlowCallback)
|
||||||
|
expected_callbacks.remove(DefaultFlowCallback)
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
trainer = self.get_trainer()
|
||||||
|
cb = trainer.pop_callback(DefaultFlowCallback)
|
||||||
|
self.assertEqual(cb.__class__, DefaultFlowCallback)
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
trainer.add_callback(DefaultFlowCallback)
|
||||||
|
expected_callbacks.insert(0, DefaultFlowCallback)
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
# We can also add, pop, or remove by instance
|
||||||
|
trainer = self.get_trainer()
|
||||||
|
cb = trainer.callback_handler.callbacks[0]
|
||||||
|
trainer.remove_callback(cb)
|
||||||
|
expected_callbacks.remove(DefaultFlowCallback)
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
trainer = self.get_trainer()
|
||||||
|
cb1 = trainer.callback_handler.callbacks[0]
|
||||||
|
cb2 = trainer.pop_callback(cb1)
|
||||||
|
self.assertEqual(cb1, cb2)
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
trainer.add_callback(cb1)
|
||||||
|
expected_callbacks.insert(0, DefaultFlowCallback)
|
||||||
|
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
|
||||||
|
|
||||||
|
def test_event_flow(self):
|
||||||
|
trainer = self.get_trainer(callbacks=[TestTrainerCallback])
|
||||||
|
trainer.train()
|
||||||
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
|
# Independent log/save/eval
|
||||||
|
trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5)
|
||||||
|
trainer.train()
|
||||||
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
|
trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5)
|
||||||
|
trainer.train()
|
||||||
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
|
trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
|
||||||
|
trainer.train()
|
||||||
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
|
trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch")
|
||||||
|
trainer.train()
|
||||||
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
|
# A bit of everything
|
||||||
|
trainer = self.get_trainer(
|
||||||
|
callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps"
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
events = trainer.callback_handler.callbacks[-2].events
|
||||||
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
Reference in New Issue
Block a user