Upstream (and rename) sortish sampler (#9574)
* Upstream (and rename) sortish sampler * Use proper sampler * Update src/transformers/trainer_pt_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -20,10 +20,11 @@ import math
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
@@ -390,3 +391,136 @@ class LabelSmoother:
|
||||
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
|
||||
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
|
||||
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
|
||||
|
||||
|
||||
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
|
||||
"""
|
||||
Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
|
||||
similar lengths. To do this, the indices are:
|
||||
|
||||
- randomly permuted
|
||||
- grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
|
||||
- sorted by length in each mega-batch
|
||||
|
||||
The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
|
||||
maximum length placed first, so that an OOM happens sooner rather than later.
|
||||
"""
|
||||
# Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
|
||||
if mega_batch_mult is None:
|
||||
mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
|
||||
# Just in case, for tiny datasets
|
||||
if mega_batch_mult == 0:
|
||||
mega_batch_mult = 1
|
||||
|
||||
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
||||
indices = torch.randperm(len(lengths), generator=generator)
|
||||
megabatch_size = mega_batch_mult * batch_size
|
||||
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
|
||||
megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
|
||||
|
||||
# The rest is to get the biggest batch first.
|
||||
# Since each megabatch is sorted by descending length, the longest element is the first
|
||||
megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
|
||||
max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
|
||||
# Switch to put the longest element in first position
|
||||
megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
|
||||
|
||||
return sum(megabatches, [])
|
||||
|
||||
|
||||
class LengthGroupedSampler(Sampler):
|
||||
r"""
|
||||
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
|
||||
keeping a bit of randomness.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset, batch_size: int, lengths: Optional[List[int]] = None):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
if lengths is None:
|
||||
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
|
||||
raise ValueError(
|
||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||
"'input_ids' key."
|
||||
)
|
||||
lengths = [len(feature["input_ids"]) for feature in dataset]
|
||||
self.lengths = lengths
|
||||
|
||||
def __len__(self):
|
||||
return len(self.lengths)
|
||||
|
||||
def __iter__(self):
|
||||
indices = get_length_grouped_indices(self.lengths, self.batch_size)
|
||||
return iter(indices)
|
||||
|
||||
|
||||
class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
r"""
|
||||
Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
|
||||
length while keeping a bit of randomness.
|
||||
"""
|
||||
# Copied and adapted from PyTorch DistributedSampler.
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
batch_size: int,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
lengths: Optional[List[int]] = None,
|
||||
):
|
||||
if num_replicas is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
if rank is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = torch.distributed.get_rank()
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
|
||||
else:
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.seed = seed
|
||||
|
||||
if lengths is None:
|
||||
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
|
||||
raise ValueError(
|
||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||
"'input_ids' key."
|
||||
)
|
||||
lengths = [len(feature["input_ids"]) for feature in dataset]
|
||||
self.lengths = lengths
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
# Deterministically shuffle based on epoch and seed
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
|
||||
|
||||
if not self.drop_last:
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[: (self.total_size - len(indices))]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
Reference in New Issue
Block a user