Floating-point operations logging in trainer (#6768)
* neFLOs calculation, logging, and reloading (#1) * testing distributed consecutive batches * fixed AttributeError from DataParallel * removed verbosity * rotate with use_mtime=True * removed print * fixed interaction with gradient accumulation * indent formatting * distributed neflo counting * fixed typo * fixed typo * mean distributed losses * exporting log history * moved a few functions * floating_point_ops clarification for transformers with parameter-reuse * code quality * double import * made flo estimation more task-agnostic * only logging flos if computed * code quality * unused import * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Sylvain review * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * black Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,8 +17,9 @@
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, dtype, nn
|
||||
@@ -45,7 +46,6 @@ from .utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from torch.nn import Identity
|
||||
except ImportError:
|
||||
@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
|
||||
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
|
||||
"""
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False) -> int:
|
||||
"""
|
||||
Get the number of (optionally, trainable) parameters in the model.
|
||||
|
||||
Args:
|
||||
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return only the number of trainable parameters
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of parameters.
|
||||
"""
|
||||
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
|
||||
return sum(p.numel() for p in params)
|
||||
|
||||
@staticmethod
|
||||
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
|
||||
try:
|
||||
@@ -307,9 +293,77 @@ class ModuleUtilsMixin:
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
|
||||
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
|
||||
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
|
||||
return head_mask
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
||||
"""
|
||||
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
||||
|
||||
Args:
|
||||
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return only the number of trainable parameters
|
||||
|
||||
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return only the number of non-embeddings parameters
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of parameters.
|
||||
"""
|
||||
|
||||
def parameter_filter(x):
|
||||
return (x.requires_grad or not only_trainable) and not (
|
||||
isinstance(x, torch.nn.Embedding) and exclude_embeddings
|
||||
)
|
||||
|
||||
params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
|
||||
return sum(p.numel() for p in params)
|
||||
|
||||
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
|
||||
"""
|
||||
Helper function to estimate the total number of tokens from the model inputs.
|
||||
|
||||
Args:
|
||||
inputs (:obj:`dict`): The model inputs.
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The total number of tokens.
|
||||
"""
|
||||
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
|
||||
if token_inputs:
|
||||
return sum([token_input.numel() for token_input in token_inputs])
|
||||
else:
|
||||
warnings.warn(
|
||||
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
|
||||
)
|
||||
return 0
|
||||
|
||||
def floating_point_ops(
|
||||
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
|
||||
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
|
||||
tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper <https://arxiv.org/pdf/2001.08361.pdf>`__ section
|
||||
2.1. Should be overriden for transformers with parameter re-use e.g. Albert or Universal Transformers, or
|
||||
if doing long-range modeling with very high sequence lengths.
|
||||
|
||||
Args:
|
||||
batch_size (:obj:`int`):
|
||||
The batch size for the forward pass.
|
||||
|
||||
sequence_length (:obj:`int`):
|
||||
The number of tokens in each line of the batch.
|
||||
|
||||
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to count embedding and softmax operations.
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
|
||||
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
r"""
|
||||
|
||||
Reference in New Issue
Block a user