Upstream (and rename) sortish sampler (#9574)

* Upstream (and rename) sortish sampler

* Use proper sampler

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2021-01-14 10:38:14 -05:00
committed by GitHub
parent 3f40070c88
commit 329fe2746a
5 changed files with 202 additions and 12 deletions

View File

@@ -20,10 +20,11 @@ import math
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import Iterator, List, Optional, Union
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
@@ -390,3 +391,136 @@ class LabelSmoother:
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
"""
Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
similar lengths. To do this, the indices are:
- randomly permuted
- grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
- sorted by length in each mega-batch
The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
maximum length placed first, so that an OOM happens sooner rather than later.
"""
# Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
if mega_batch_mult is None:
mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
# Just in case, for tiny datasets
if mega_batch_mult == 0:
mega_batch_mult = 1
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = mega_batch_mult * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
# The rest is to get the biggest batch first.
# Since each megabatch is sorted by descending length, the longest element is the first
megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
# Switch to put the longest element in first position
megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
return sum(megabatches, [])
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(self, dataset: Dataset, batch_size: int, lengths: Optional[List[int]] = None):
self.dataset = dataset
self.batch_size = batch_size
if lengths is None:
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
"'input_ids' key."
)
lengths = [len(feature["input_ids"]) for feature in dataset]
self.lengths = lengths
def __len__(self):
return len(self.lengths)
def __iter__(self):
indices = get_length_grouped_indices(self.lengths, self.batch_size)
return iter(indices)
class DistributedLengthGroupedSampler(DistributedSampler):
r"""
Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
length while keeping a bit of randomness.
"""
# Copied and adapted from PyTorch DistributedSampler.
def __init__(
self,
dataset: Dataset,
batch_size: int,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
drop_last: bool = False,
lengths: Optional[List[int]] = None,
):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.batch_size = batch_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
self.seed = seed
if lengths is None:
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
"'input_ids' key."
)
lengths = [len(feature["input_ids"]) for feature in dataset]
self.lengths = lengths
def __iter__(self) -> Iterator:
# Deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
if not self.drop_last:
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)