From 329fe2746a4931614438a74effeb3d77005e4c53 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 14 Jan 2021 10:38:14 -0500 Subject: [PATCH] Upstream (and rename) sortish sampler (#9574) * Upstream (and rename) sortish sampler * Use proper sampler * Update src/transformers/trainer_pt_utils.py Co-authored-by: Lysandre Debut Co-authored-by: Lysandre Debut --- examples/seq2seq/test_finetune_trainer.py | 2 +- src/transformers/trainer.py | 37 ++++-- src/transformers/trainer_pt_utils.py | 136 +++++++++++++++++++++- src/transformers/training_args.py | 7 ++ tests/test_trainer_utils.py | 32 ++++- 5 files changed, 202 insertions(+), 12 deletions(-) diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 03f0e3183c..0affe52902 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -169,7 +169,7 @@ class TestFinetuneTrainer(TestCasePlus): --logging_steps 0 --save_steps {str(eval_steps)} --eval_steps {str(eval_steps)} - --sortish_sampler + --group_by_length --label_smoothing_factor 0.1 --adafactor --task translation diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3138cf1da2..a58119f88f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -70,12 +70,13 @@ from .trainer_callback import ( TrainerState, ) from .trainer_pt_utils import ( + DistributedLengthGroupedSampler, DistributedTensorGatherer, LabelSmoother, + LengthGroupedSampler, SequentialDistributedSampler, distributed_broadcast_scalars, distributed_concat, - get_tpu_sampler, nested_concat, nested_detach, nested_numpify, @@ -94,7 +95,7 @@ from .trainer_utils import ( set_seed, speed_metrics, ) -from .training_args import TrainingArguments +from .training_args import ParallelMode, TrainingArguments from .utils import logging @@ -448,14 +449,32 @@ class Trainer: self.train_dataset, collections.abc.Sized ): return None - elif 
is_torch_tpu_available(): - return get_tpu_sampler(self.train_dataset) + + # Gather the number of processes and this process index. + if self.args.parallel_mode == ParallelMode.TPU: + num_processes = xm.xrt_world_size() + process_index = xm.get_ordinal() + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + num_processes = torch.distributed.get_world_size() + process_index = torch.distributed.get_rank() else: - return ( - RandomSampler(self.train_dataset) - if self.args.local_rank == -1 - else DistributedSampler(self.train_dataset) - ) + num_processes = 1 + process_index = 0 + + # Build the sampler. + if self.args.group_by_length: + if num_processes <= 1: + return LengthGroupedSampler(self.train_dataset, self.args.train_batch_size) + else: + return DistributedLengthGroupedSampler( + self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index + ) + + else: + if num_processes <= 1: + return RandomSampler(self.train_dataset) + else: + return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index) def get_train_dataloader(self) -> DataLoader: """ diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 89d51f5c4c..850f8f8415 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -20,10 +20,11 @@ import math import warnings from contextlib import contextmanager from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Union import numpy as np import torch +from torch.utils.data.dataset import Dataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler @@ -390,3 +391,136 @@ class LabelSmoother: # Take the mean over the label dimensions, then divide by the number of active elements (i.e. 
not-padded): smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum()) return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss + + +def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None): + """ + Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of + similar lengths. To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size` + - sorted by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. + if mega_batch_mult is None: + mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) + # Just in case, for tiny datasets + if mega_batch_mult == 0: + mega_batch_mult = 1 + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = mega_batch_mult * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] + + # The rest is to get the biggest batch first. 
+ # Since each megabatch is sorted by descending length, the longest element is the first + megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] + max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item() + # Switch to put the longest element in first position + megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0] + + return sum(megabatches, []) + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ + + def __init__(self, dataset: Dataset, batch_size: int, lengths: Optional[List[int]] = None): + self.dataset = dataset + self.batch_size = batch_size + if lengths is None: + if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]: + raise ValueError( + "Can only automatically infer lengths for datasets whose items are dictionaries with an " + "'input_ids' key." + ) + lengths = [len(feature["input_ids"]) for feature in dataset] + self.lengths = lengths + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + indices = get_length_grouped_indices(self.lengths, self.batch_size) + return iter(indices) + + +class DistributedLengthGroupedSampler(DistributedSampler): + r""" + Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same + length while keeping a bit of randomness. + """ + # Copied and adapted from PyTorch DistributedSampler. 
+ def __init__( + self, + dataset: Dataset, + batch_size: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + seed: int = 0, + drop_last: bool = False, + lengths: Optional[List[int]] = None, + ): + if num_replicas is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = torch.distributed.get_world_size() + if rank is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = torch.distributed.get_rank() + self.dataset = dataset + self.batch_size = batch_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + self.seed = seed + + if lengths is None: + if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]: + raise ValueError( + "Can only automatically infer lengths for datasets whose items are dictionaries with an " + "'input_ids' key." 
+ ) + lengths = [len(feature["input_ids"]) for feature in dataset] + self.lengths = lengths + + def __iter__(self) -> Iterator: + # Deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g) + + if not self.drop_last: + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a85e47d9ea..abef9e35c6 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -227,6 +227,9 @@ class TrainingArguments: adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of :class:`~transformers.AdamW`. + group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to group together samples of roughly the same length in the training dataset (to minimize + padding applied and be more efficient). Only useful if applying dynamic padding.
""" output_dir: str = field( @@ -405,6 +408,10 @@ class TrainingArguments: default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} ) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."}) + group_by_length: bool = field( + default=False, + metadata={"help": "Whether or not to group samples of roughly the same length together when batching."}, + ) _n_gpu: int = field(init=False, repr=False, default=-1) def __post_init__(self): diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index a62e95ac6b..f375ca5367 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -25,7 +25,12 @@ if is_torch_available(): import torch from transformers.modeling_outputs import SequenceClassifierOutput - from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother + from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedTensorGatherer, + LabelSmoother, + LengthGroupedSampler, + ) @require_torch @@ -87,3 +92,28 @@ class TrainerUtilsTest(unittest.TestCase): log_probs[2, 3] = 0.0 expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17) self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss)) + + def test_group_by_length(self): + # Get some inputs of random lengths + lengths = torch.randint(0, 25, (100,)).tolist() + # Put one bigger than the others to check it ends up in first position + lengths[32] = 50 + + indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths)) + # The biggest element should be first + self.assertEqual(lengths[indices[0]], 50) + # The indices should be a permutation of range(100) + self.assertEqual(list(sorted(indices)), list(range(100))) + + def test_distributed_length_grouped(self): + # Get some inputs of random lengths + lengths = torch.randint(0, 25, (100,)).tolist() + # Put one bigger than the others to check it ends up in 
first position + lengths[32] = 50 + + indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths)) + indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths)) + # The biggest element should be first + self.assertEqual(lengths[indices_process_0[0]], 50) + # The indices should be a permutation of range(100) + self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))