Floating-point operations logging in trainer (#6768)
* neFLOs calculation, logging, and reloading (#1) * testing distributed consecutive batches * fixed AttributeError from DataParallel * removed verbosity * rotate with use_mtime=True * removed print * fixed interaction with gradient accumulation * indent formatting * distributed neflo counting * fixed typo * fixed typo * mean distributed losses * exporting log history * moved a few functions * floating_point_ops clarification for transformers with parameter-reuse * code quality * double import * made flo estimation more task-agnostic * only logging flos if computed * code quality * unused import * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Sylvain review * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * black Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,8 +17,9 @@
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, dtype, nn
|
||||
@@ -45,7 +46,6 @@ from .utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from torch.nn import Identity
|
||||
except ImportError:
|
||||
@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
|
||||
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
|
||||
"""
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False) -> int:
|
||||
"""
|
||||
Get the number of (optionally, trainable) parameters in the model.
|
||||
|
||||
Args:
|
||||
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return only the number of trainable parameters
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of parameters.
|
||||
"""
|
||||
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
|
||||
return sum(p.numel() for p in params)
|
||||
|
||||
@staticmethod
|
||||
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
|
||||
try:
|
||||
@@ -307,9 +293,77 @@ class ModuleUtilsMixin:
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
|
||||
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
|
||||
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
|
||||
return head_mask
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
||||
"""
|
||||
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
||||
|
||||
Args:
|
||||
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return only the number of trainable parameters
|
||||
|
||||
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return only the number of non-embeddings parameters
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of parameters.
|
||||
"""
|
||||
|
||||
def parameter_filter(x):
|
||||
return (x.requires_grad or not only_trainable) and not (
|
||||
isinstance(x, torch.nn.Embedding) and exclude_embeddings
|
||||
)
|
||||
|
||||
params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
|
||||
return sum(p.numel() for p in params)
|
||||
|
||||
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
|
||||
"""
|
||||
Helper function to estimate the total number of tokens from the model inputs.
|
||||
|
||||
Args:
|
||||
inputs (:obj:`dict`): The model inputs.
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The total number of tokens.
|
||||
"""
|
||||
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
|
||||
if token_inputs:
|
||||
return sum([token_input.numel() for token_input in token_inputs])
|
||||
else:
|
||||
warnings.warn(
|
||||
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
|
||||
)
|
||||
return 0
|
||||
|
||||
def floating_point_ops(
|
||||
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
|
||||
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
|
||||
tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper <https://arxiv.org/pdf/2001.08361.pdf>`__ section
|
||||
2.1. Should be overriden for transformers with parameter re-use e.g. Albert or Universal Transformers, or
|
||||
if doing long-range modeling with very high sequence lengths.
|
||||
|
||||
Args:
|
||||
batch_size (:obj:`int`):
|
||||
The batch size for the forward pass.
|
||||
|
||||
sequence_length (:obj:`int`):
|
||||
The number of tokens in each line of the batch.
|
||||
|
||||
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to count embedding and softmax operations.
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
|
||||
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
r"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
@@ -42,6 +43,8 @@ from .trainer_utils import (
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
set_seed,
|
||||
)
|
||||
from .training_args import TrainingArguments
|
||||
@@ -146,7 +149,7 @@ class SequentialDistributedSampler(Sampler):
|
||||
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
||||
assert (
|
||||
len(indices) == self.num_samples
|
||||
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"
|
||||
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
||||
|
||||
return iter(indices)
|
||||
|
||||
@@ -241,6 +244,7 @@ class Trainer:
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
self.tb_writer = tb_writer
|
||||
self.log_history = []
|
||||
if "prediction_loss_only" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
|
||||
@@ -284,6 +288,7 @@ class Trainer:
|
||||
|
||||
self.global_step = None
|
||||
self.epoch = None
|
||||
self.total_flos = None
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
self.hp_search_backend = None
|
||||
@@ -461,7 +466,11 @@ class Trainer:
|
||||
logger.info(
|
||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||
)
|
||||
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||
try:
|
||||
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||
except AttributeError:
|
||||
# in case the model has no config
|
||||
combined_dict = {**self.args.to_sanitized_dict()}
|
||||
wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
||||
)
|
||||
@@ -663,6 +672,7 @@ class Trainer:
|
||||
|
||||
self.global_step = 0
|
||||
self.epoch = 0
|
||||
self.total_flos = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
@@ -670,6 +680,8 @@ class Trainer:
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
try:
|
||||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
||||
self.total_flos = getattr(model.config, "total_flos", 0)
|
||||
|
||||
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = self.global_step % (
|
||||
len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||
@@ -678,9 +690,11 @@ class Trainer:
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", self.global_step)
|
||||
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
except ValueError:
|
||||
self.global_step = 0
|
||||
self.total_flos = 0
|
||||
logger.info(" Starting fine-tuning.")
|
||||
|
||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||
@@ -714,6 +728,7 @@ class Trainer:
|
||||
continue
|
||||
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
self.total_flos += self.floating_point_ops(inputs)
|
||||
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||
@@ -784,7 +799,7 @@ class Trainer:
|
||||
self.save_model(output_dir)
|
||||
|
||||
if self.is_world_process_zero():
|
||||
self._rotate_checkpoints()
|
||||
self._rotate_checkpoints(use_mtime=True)
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
@@ -924,6 +939,13 @@ class Trainer:
|
||||
|
||||
if self.epoch is not None:
|
||||
logs["epoch"] = self.epoch
|
||||
if self.total_flos is not None:
|
||||
if self.args.local_rank != -1:
|
||||
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||
else:
|
||||
total_flos = self.total_flos
|
||||
if total_flos > 0:
|
||||
logs["total_flos"] = self.total_flos
|
||||
if self.global_step is None:
|
||||
# when logging evaluation metrics without training
|
||||
self.global_step = 0
|
||||
@@ -951,6 +973,8 @@ class Trainer:
|
||||
if experiment is not None:
|
||||
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
|
||||
output = {**logs, **{"step": self.global_step}}
|
||||
if self.is_world_process_zero():
|
||||
self.log_history.append(output)
|
||||
if iterator is not None:
|
||||
iterator.write(output)
|
||||
else:
|
||||
@@ -1089,6 +1113,9 @@ class Trainer:
|
||||
if xm.is_master_ordinal():
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||
json.dump(
|
||||
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
||||
)
|
||||
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
@@ -1096,6 +1123,7 @@ class Trainer:
|
||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
||||
|
||||
xm.rendezvous("saving_checkpoint")
|
||||
self._store_flos()
|
||||
self.model.save_pretrained(output_dir)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
@@ -1108,12 +1136,26 @@ class Trainer:
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, PreTrainedModel):
|
||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
||||
self._store_flos()
|
||||
self.model.save_pretrained(output_dir)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||
json.dump(
|
||||
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
||||
)
|
||||
|
||||
def _store_flos(self):
|
||||
# Storing the number of floating-point operations that went into the model
|
||||
if self.total_flos is not None:
|
||||
if self.args.local_rank != -1:
|
||||
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||
else:
|
||||
total_flos = self.total_flos
|
||||
if total_flos > 0:
|
||||
self.model.config.total_flos = total_flos
|
||||
|
||||
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
||||
ordering_and_checkpoint_path = []
|
||||
@@ -1245,13 +1287,11 @@ class Trainer:
|
||||
self._past = None
|
||||
|
||||
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
||||
samples_count = 0
|
||||
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||
samples_count += batch_size
|
||||
if loss is not None:
|
||||
eval_losses.append(loss * batch_size)
|
||||
eval_losses.extend([loss] * batch_size)
|
||||
if logits is not None:
|
||||
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
|
||||
if labels is not None:
|
||||
@@ -1264,9 +1304,9 @@ class Trainer:
|
||||
if self.args.local_rank != -1:
|
||||
# In distributed mode, concatenate all results from all nodes:
|
||||
if preds is not None:
|
||||
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
|
||||
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
|
||||
if label_ids is not None:
|
||||
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
|
||||
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
|
||||
elif is_torch_tpu_available():
|
||||
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
||||
if preds is not None:
|
||||
@@ -1289,7 +1329,14 @@ class Trainer:
|
||||
else:
|
||||
metrics = {}
|
||||
if len(eval_losses) > 0:
|
||||
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
|
||||
if self.args.local_rank != -1:
|
||||
metrics["eval_loss"] = (
|
||||
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
else:
|
||||
metrics["eval_loss"] = np.mean(eval_losses)
|
||||
|
||||
# Prefix all keys with eval_
|
||||
for key in list(metrics.keys()):
|
||||
@@ -1298,18 +1345,6 @@ class Trainer:
|
||||
|
||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||
|
||||
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
|
||||
assert self.args.local_rank != -1
|
||||
|
||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensor)
|
||||
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
output = concat[:num_total_examples]
|
||||
return output
|
||||
|
||||
def prediction_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
@@ -1355,3 +1390,32 @@ class Trainer:
|
||||
if labels is not None:
|
||||
labels = labels.detach()
|
||||
return (loss, logits.detach(), labels)
|
||||
|
||||
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
|
||||
"""
|
||||
For models that inherit from :class:`~transformers.PretrainedModel`, uses
|
||||
that method to compute the number of floating point operations for every backward + forward pass. If using
|
||||
another model, either implement such a method in the model or subclass and override this method.
|
||||
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to evaluate.
|
||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
|
||||
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
|
||||
self.model, torch.nn.parallel.DistributedDataParallel
|
||||
):
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
if hasattr(model, "floating_point_ops"):
|
||||
return model.floating_point_ops(inputs)
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import random
|
||||
from typing import Any, Dict, NamedTuple, Optional
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .file_utils import is_tf_available, is_torch_available
|
||||
from .tokenization_utils_base import ExplicitEnum
|
||||
@@ -126,3 +127,32 @@ default_hp_space = {
|
||||
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
||||
HPSearchBackend.RAY: default_hp_space_ray,
|
||||
}
|
||||
|
||||
|
||||
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
|
||||
assert self.args.local_rank != -1
|
||||
|
||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensor)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
|
||||
|
||||
def distributed_broadcast_scalars(
|
||||
self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
assert self.args.local_rank != -1
|
||||
|
||||
tensorized_scalar = torch.Tensor(scalars).cuda()
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
|
||||
Reference in New Issue
Block a user