Trainer callbacks (#7596)

* Initial callback proposal

* Finish various callbacks

* Post-rebase conflicts

* Fix tests

* Don't use something that's not set

* Documentation

* Remove unwanted print.

* Document all models can work

* Add tests + small fixes

* Update docs/source/internal/trainer_utils.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* Fix TF tests

* Real fix this time

* This one should work

* Fix typo

* Really fix typo

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-10-07 10:50:21 -04:00
committed by GitHub
parent 8fa0c956b3
commit 08ba4b4902
15 changed files with 1340 additions and 483 deletions

View File

@@ -1,10 +1,26 @@
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""
import inspect
import math
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -15,8 +31,7 @@ from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
@@ -34,23 +49,35 @@ from .modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalPrediction,
EvaluationStrategy,
HPSearchBackend,
PredictionOutput,
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
TrainOutput,
default_compute_objective,
default_hp_space,
)
from .trainer_pt_utils import (
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_tpu_sampler,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
reissue_pt_warnings,
)
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalPrediction,
HPSearchBackend,
PredictionOutput,
TrainOutput,
default_compute_objective,
default_hp_space,
set_seed,
)
from .training_args import TrainingArguments
@@ -60,7 +87,8 @@ from .utils import logging
_use_native_amp = False
_use_apex = False
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
DEFAULT_CALLBACKS = [DefaultFlowCallback]
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
@@ -82,16 +110,20 @@ if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl
if is_tensorboard_available():
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
from .integrations import TensorBoardCallback
DEFAULT_CALLBACKS.append(TensorBoardCallback)
if is_wandb_available():
import wandb
from .integrations import WandbCallback
DEFAULT_CALLBACKS.append(WandbCallback)
if is_comet_available():
import comet_ml
from .integrations import CometCallback
DEFAULT_CALLBACKS.append(CometCallback)
if is_optuna_available():
import optuna
@@ -102,91 +134,20 @@ if is_ray_available():
logger = logging.get_logger(__name__)
def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
warnings.warn(w.message, w.category)
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
Args:
local_rank (:obj:`int`): The rank of the local process.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
class SequentialDistributedSampler(Sampler):
"""
Distributed Sampler that subsamples indicies sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert (
len(indices) == self.total_size
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices)
def __len__(self):
return self.num_samples
def get_tpu_sampler(dataset: Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch,
optimized for 🤗 Transformers.
Args:
model (:class:`~transformers.PreTrainedModel`, `optional`):
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
.. note::
:class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
they work the same way as the 🤗 Transformers models.
args (:class:`~transformers.TrainingArguments`, `optional`):
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
@@ -210,8 +171,11 @@ class Trainer:
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
tb_writer (:obj:`SummaryWriter`, `optional`):
Object to write to TensorBoard.
callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in :doc:`here <callback>`.
If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
:class:`~transformers.AdamW` on your model and a scheduler given by
@@ -222,7 +186,7 @@ class Trainer:
def __init__(
self,
model: PreTrainedModel = None,
model: Union[PreTrainedModel, torch.nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
@@ -230,7 +194,7 @@ class Trainer:
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
model_init: Callable[[], PreTrainedModel] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
tb_writer: Optional["SummaryWriter"] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
**kwargs,
):
@@ -259,7 +223,21 @@ class Trainer:
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
self.tb_writer = tb_writer
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
# Deprecated arguments
if "tb_writer" in kwargs:
warnings.warn(
"Passing `tb_writer` as a keyword argument is deprecated and won't be possible in a "
+ "future version. Use `TensorBoardCallback(tb_writer=...)` instead and pass it to the `callbacks`"
+ "argument",
FutureWarning,
)
tb_writer = kwargs.pop("tb_writer")
self.remove_callback(TensorBoardCallback)
self.add_callback(TensorBoardCallback(tb_writer=tb_writer))
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
@@ -270,13 +248,6 @@ class Trainer:
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available():
logger.warning(
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
)
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
self._loggers_initialized = False
@@ -304,6 +275,7 @@ class Trainer:
self._remove_unused_columns(self.eval_dataset, description="evaluation")
self.state = TrainerState()
self.control = TrainerControl()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
# state at each call to self.log.
self._total_flos = None
@@ -317,6 +289,45 @@ class Trainer:
else ["labels"]
)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
def add_callback(self, callback):
"""
Add a callback to the current list of :class:`~transformer.TrainerCallback`.
Args:
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
In the first case, will instantiate a member of that class.
"""
self.callback_handler.add_callback(callback)
def pop_callback(self, callback):
"""
Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.
If the callback is not found, returns :obj:`None` (and no error is raised).
Args:
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
In the first case, will pop the first member of that class found in the list of callbacks.
Returns:
:class:`~transformer.TrainerCallback`: The callback removed, if found.
"""
return self.callback_handler.pop_callback(callback)
def remove_callback(self, callback):
"""
Remove a callback from the current list of :class:`~transformer.TrainerCallback`.
Args:
callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
In the first case, will remove the first member of that class found in the list of callbacks.
"""
self.callback_handler.remove_callback(callback)
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
@@ -465,102 +476,12 @@ class Trainer:
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
def setup_wandb(self):
"""
Setup the optional Weights & Biases (`wandb`) integration.
One can subclass and override this method to customize the setup if needed. Find more information
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_WATCH:
(Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
or "all" to log gradients and parameters
WANDB_PROJECT:
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
WANDB_DISABLED:
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
"""
if hasattr(self, "_setup_wandb"):
warnings.warn(
"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
FutureWarning,
)
return self._setup_wandb()
if self.is_world_process_zero():
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**self.args.to_sanitized_dict()}
if isinstance(self.model, PreTrainedModel):
combined_dict = {**self.model.config.to_dict(), **combined_dict}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
)
def setup_comet(self):
"""
Setup the optional Comet.ml integration.
Environment:
COMET_MODE:
(Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
COMET_PROJECT_NAME:
(Optional): str - Comet.ml project name for experiments
COMET_OFFLINE_DIRECTORY:
(Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
For a number of configurable items in the environment,
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
"""
if self.is_world_master():
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
experiment = None
if comet_mode == "ONLINE":
experiment = comet_ml.Experiment(**args)
logger.info("Automatic Comet.ml online logging enabled")
elif comet_mode == "OFFLINE":
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**args)
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
if experiment is not None:
experiment._set_model_graph(self.model, framework="transformers")
experiment._log_parameters(self.args, prefix="args/", framework="transformers")
if isinstance(self.model, PreTrainedModel):
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
def num_examples(self, dataloader: DataLoader) -> int:
"""
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
"""
return len(dataloader.dataset)
def _setup_loggers(self):
if self._loggers_initialized:
return
if is_wandb_available():
self.setup_wandb()
elif os.environ.get("WANDB_DISABLED") != "true":
logger.info(
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
)
if is_comet_available():
self.setup_comet()
elif os.environ.get("COMET_MODE") != "DISABLED":
logger.info(
"To use comet_ml logging, run `pip/conda install comet_ml` "
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
self._loggers_initialized = True
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """
if self.hp_search_backend is None or trial is None:
@@ -661,7 +582,7 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
# Moxed precision training with apex (torch < 1.6)
# Mixed precision training with apex (torch < 1.6)
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
@@ -687,10 +608,6 @@ class Trainer:
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
if self.tb_writer is not None:
self.tb_writer.add_text("args", self.args.to_json_string())
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
# Train!
if is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
@@ -723,17 +640,25 @@ class Trainer:
logger.info(" Continuing training from global step %d", self.state.global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
# Update the references
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
tr_loss = torch.tensor(0.0).to(self.args.device)
self._logging_loss_scalar = 0
self._total_flos = self.state.total_flos
logging_loss_scalar = 0.0
model.zero_grad()
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
@@ -750,15 +675,18 @@ class Trainer:
if self.args.past_index >= 0:
self._past = None
epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
for step, inputs in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
epoch_pbar.update(1)
continue
if (step + 1) % self.args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
tr_loss += self.training_step(model, inputs)
self._total_flos += self.floating_point_ops(inputs)
@@ -787,50 +715,15 @@ class Trainer:
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
self.state.global_step == 1 and self.args.logging_first_step
):
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
logging_loss_scalar = tr_loss_scalar
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self.log(logs)
if (
self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.state.global_step % self.args.eval_steps == 0
):
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
if (
not self.args.load_best_model_at_end
and self.args.save_steps > 0
and self.state.global_step % self.args.save_steps == 0
):
self._save_training(model, trial)
epoch_pbar.update(1)
if self.state.global_step >= max_steps:
if self.control.should_epoch_stop or self.control.should_training_stop:
break
epoch_pbar.close()
train_pbar.update(1)
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.load_best_model_at_end:
self._save_training(model, trial, metrics=metrics)
self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -841,12 +734,9 @@ class Trainer:
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
if self.state.global_step >= max_steps:
if self.control.should_training_stop:
break
train_pbar.close()
if self.tb_writer:
self.tb_writer.close()
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
@@ -863,9 +753,36 @@ class Trainer:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
def _save_training(self, model, trial, metrics=None):
def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
self._logging_loss_scalar = tr_loss_scalar
self.log(logs)
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
if hasattr(model, "module"):
@@ -896,7 +813,7 @@ class Trainer:
reissue_pt_warnings(caught_warnings)
# Determine the new best metric / best model checkpoint
if metrics is not None:
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
@@ -998,7 +915,7 @@ class Trainer:
self.hp_search_backend = None
return best_run
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
def log(self, logs: Dict[str, float]) -> None:
"""
Log :obj:`logs` on the various objects watching training.
@@ -1007,55 +924,22 @@ class Trainer:
Args:
logs (:obj:`Dict[str, float]`):
The values to log.
iterator (:obj:`tqdm`, `optional`):
A potential tqdm progress bar to write the logs on.
"""
# Set up loggers like W&B or Comet ML
self._setup_loggers()
if hasattr(self, "_log"):
warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
FutureWarning,
)
return self._log(logs, iterator=iterator)
return self._log(logs)
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self._total_flos is not None:
self.store_flos()
logs["total_flos"] = self.state.total_flos
if self.tb_writer:
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, self.state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
'"%s" of type %s for key "%s" as a scalar. '
"This invocation of Tensorboard's writer.add_scalar() "
"is incorrect so we dropped this attribute.",
v,
type(v),
k,
)
self.tb_writer.flush()
if is_wandb_available():
if self.is_world_process_zero():
wandb.log(logs, step=self.state.global_step)
if is_comet_available():
if self.is_world_process_zero():
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(
logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
if iterator is not None:
iterator.write(output)
else:
print(output)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
@@ -1372,8 +1256,9 @@ class Trainer:
if self.args.past_index >= 0:
self._past = None
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
self.callback_handler.eval_dataloader = dataloader
for inputs in dataloader:
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
if loss is not None:
@@ -1382,6 +1267,7 @@ class Trainer:
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
if labels is not None:
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop