Add predict step accumulation (#7767)

* Add eval_accumulation_step and clean distributed eval

* Add TPU test

* Add TPU stuff

* Fix arg name

* Fix Seq2SeqTrainer

* Fix total_size

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Doc and add test to TPU

* Add unit test

* Adapt name

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-10-14 11:41:45 -04:00
committed by GitHub
parent 8feb0cc967
commit a1d1b332d0
10 changed files with 413 additions and 47 deletions

View File

@@ -21,11 +21,13 @@ import warnings
from contextlib import contextmanager
from typing import List, Optional, Union
import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_torch_tpu_available
from .utils import logging
if is_torch_tpu_available():
@@ -33,6 +35,8 @@ if is_torch_tpu_available():
# Message shown to users who save/load a learning-rate scheduler without also saving/loading the optimizer state.
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimizer when saving or loading the scheduler."
logger = logging.get_logger(__name__)
def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
@@ -41,7 +45,12 @@ def nested_concat(tensors, new_tensors, dim=0):
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
elif isinstance(tensors, torch.Tensor):
return torch.cat((tensors, new_tensors), dim=dim)
elif isinstance(tensors, np.ndarray):
return np.concatenate((tensors, new_tensors), axis=dim)
else:
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
def nested_numpify(tensors):
@@ -177,3 +186,112 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
def nested_new_like(arrays, num_samples):
    """
    Build zero-filled arrays mirroring the nested structure of `arrays`.

    Every leaf keeps its dtype and trailing dimensions, but its first dimension is set to `num_samples`. Lists
    and tuples are rebuilt with the same container type.
    """
    if not isinstance(arrays, (list, tuple)):
        # Leaf: replace the leading dimension only, keep everything after it.
        new_shape = (num_samples,) + tuple(arrays.shape[1:])
        return np.zeros(new_shape, dtype=arrays.dtype)
    container = type(arrays)
    return container(nested_new_like(item, num_samples) for item in arrays)
def nested_truncate(tensors, limit):
    """
    Keep only the first `limit` entries of `tensors`.

    When `tensors` is a list or tuple, the truncation is applied recursively to each element and the result is
    rebuilt with the same container type.
    """
    if not isinstance(tensors, (list, tuple)):
        return tensors[:limit]
    container = type(tensors)
    return container(nested_truncate(item, limit) for item in tensors)
class DistributedTensorGatherer:
    """
    A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU
    by chunks.

    If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on
    CPU at every step, our sampler will generate the following indices:

        :obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`

    to get something of size a multiple of 3 (so that each process gets the same dataset length). Then
    process 0, 1 and 2 will be responsible for making predictions for the following samples:

        - P0: :obj:`[0, 1, 2, 3, 4, 5]`
        - P1: :obj:`[6, 7, 8, 9, 10, 11]`
        - P2: :obj:`[12, 13, 14, 15, 0, 1]`

    The first batch treated on each process will be

        - P0: :obj:`[0, 1]`
        - P1: :obj:`[6, 7]`
        - P2: :obj:`[12, 13]`

    So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor)
    corresponding to the following indices:

        :obj:`[0, 1, 6, 7, 12, 13]`

    If we directly concatenate our results without taking any precautions, the user will then get
    the predictions for the indices in this order at the end of the prediction loop:

        :obj:`[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`

    For some reason, that's not going to float their boat. This class is there to solve that problem.

    Args:
        world_size (:obj:`int`):
            The number of processes used in the distributed training.
        num_samples (:obj:`int`):
            The number of samples in our dataset.
        make_multiple_of (:obj:`int`, `optional`):
            If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
            (by adding samples).
    """

    def __init__(self, world_size, num_samples, make_multiple_of=None):
        self.world_size = world_size
        self.num_samples = num_samples
        # The padded dataset length is rounded up to a multiple of `world_size` (times `make_multiple_of` if given).
        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
        self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
        # Number of (padded) samples each process is responsible for.
        self.process_length = self.total_samples // world_size
        # Both are lazily created on the first call to `add_arrays`.
        self._storage = None
        self._offsets = None

    def add_arrays(self, arrays):
        """
        Add :obj:`arrays` to the internal storage. Will initialize the storage to the full size at the first arrays
        passed so that if we're bound to get an OOM, it happens at the beginning.

        Args:
            arrays:
                An array (or nested list/tuple of arrays) whose first dimension is a multiple of ``world_size``
                and holds the chunks of each process one after the other. :obj:`None` is ignored.
        """
        if arrays is None:
            return
        if self._storage is None:
            self._storage = nested_new_like(arrays, self.total_samples)
            # One write cursor per process, each starting at the beginning of that process' region.
            self._offsets = list(range(0, self.total_samples, self.process_length))
        slice_len = self._nested_set_tensors(self._storage, arrays)
        # Advance every cursor only once all leaves of the nested structure have been written.
        self._offsets = [offset + slice_len for offset in self._offsets]

    def _nested_set_tensors(self, storage, arrays):
        """
        Copy :obj:`arrays` into :obj:`storage` at the current offsets, recursing through nested lists/tuples.

        Returns:
            The number of rows written per process (for a nested input, the value of its last element).
        """
        if isinstance(arrays, (list, tuple)):
            for x, y in zip(storage, arrays):
                slice_len = self._nested_set_tensors(x, y)
            return slice_len
        assert (
            arrays.shape[0] % self.world_size == 0
        ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."

        slice_len = arrays.shape[0] // self.world_size
        for rank in range(self.world_size):
            # Chunk `rank` of the gathered array goes to the region reserved for process `rank`.
            start = self._offsets[rank]
            storage[start : start + slice_len] = arrays[rank * slice_len : (rank + 1) * slice_len]
        return slice_len

    def finalize(self):
        """
        Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
        to get each process a dataset of the same length).

        Returns:
            The gathered arrays truncated to ``num_samples``, or :obj:`None` if nothing was ever added.
        """
        if self._storage is None:
            return None
        if self._offsets[0] != self.process_length:
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning("Not all data has been set. Are you sure you passed all values?")
        return nested_truncate(self._storage, self.num_samples)