Floating-point operations logging in trainer (#6768)

* neFLOs calculation, logging, and reloading (#1)

* testing distributed consecutive batches

* fixed AttributeError from DataParallel

* removed verbosity

* rotate with use_mtime=True

* removed print

* fixed interaction with gradient accumulation

* indent formatting

* distributed neflo counting

* fixed typo

* fixed typo

* mean distributed losses

* exporting log history

* moved a few functions

* floating_point_ops clarification for transformers with parameter-reuse

* code quality

* double import

* made flo estimation more task-agnostic

* only logging flos if computed

* code quality

* unused import

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sylvain review

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* black

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Teven
2020-09-08 16:00:56 +02:00
committed by GitHub
parent d155b38d6e
commit 01d340adfa
3 changed files with 187 additions and 39 deletions

View File

@@ -17,8 +17,9 @@
import inspect
import os
import re
import warnings
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import Tensor, device, dtype, nn
@@ -45,7 +46,6 @@ from .utils import logging
logger = logging.get_logger(__name__)
try:
from torch.nn import Identity
except ImportError:
@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get the number of (optionally, trainable) parameters in the model.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try:
@@ -307,9 +293,77 @@ class ModuleUtilsMixin:
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
return head_mask
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Get number of (optionally, trainable or non-embeddings) parameters in the module.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of non-embeddings parameters
Returns:
:obj:`int`: The number of parameters.
"""
def parameter_filter(x):
return (x.requires_grad or not only_trainable) and not (
isinstance(x, torch.nn.Embedding) and exclude_embeddings
)
params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
"""
Helper function to estimate the total number of tokens from the model inputs.
Args:
inputs (:obj:`dict`): The model inputs.
Returns:
:obj:`int`: The total number of tokens.
"""
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
if token_inputs:
return sum([token_input.numel() for token_input in token_inputs])
else:
warnings.warn(
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
)
return 0
def floating_point_ops(
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
) -> int:
"""
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper <https://arxiv.org/pdf/2001.08361.pdf>`__ section
2.1. Should be overriden for transformers with parameter re-use e.g. Albert or Universal Transformers, or
if doing long-range modeling with very high sequence lengths.
Args:
batch_size (:obj:`int`):
The batch size for the forward pass.
sequence_length (:obj:`int`):
The number of tokens in each line of the batch.
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to count embedding and softmax operations.
Returns:
:obj:`int`: The number of floating-point operations.
"""
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
r"""