From a0a027c2ed53b324cf4d0179ceec88d4ff414d47 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 16 Mar 2021 11:22:39 -0400 Subject: [PATCH] Add DistributedSamplerWithLoop (#10746) * Add DistributedSamplerWithLoop * Fix typo * Test and small fix --- src/transformers/sagemaker/trainer_sm.py | 8 ++++++ src/transformers/trainer.py | 35 +++++++++++------------- src/transformers/trainer_pt_utils.py | 30 +++++++++++++++++++- src/transformers/training_args.py | 14 ++++++++++ tests/test_trainer_utils.py | 26 ++++++++++++++++++ 5 files changed, 93 insertions(+), 20 deletions(-) diff --git a/src/transformers/sagemaker/trainer_sm.py b/src/transformers/sagemaker/trainer_sm.py index 202afb85cd..0d828b25aa 100644 --- a/src/transformers/sagemaker/trainer_sm.py +++ b/src/transformers/sagemaker/trainer_sm.py @@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model from ..trainer import Trainer from ..trainer_pt_utils import ( DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, SequentialDistributedSampler, nested_detach, nested_numpify, @@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer): return DistributedLengthGroupedSampler( self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank() ) + elif not self.args.dataloader_drop_last: + return DistributedSamplerWithLoop( + self.train_dataset, + self.args.per_device_train_batch_size, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + ) else: return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank()) else: diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 151c0c751e..794cddad77 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -77,6 +77,7 @@ from .trainer_callback import ( ) from .trainer_pt_utils import ( DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, DistributedTensorGatherer, LabelSmoother, LengthGroupedSampler, @@ 
-491,24 +492,10 @@ class Trainer: ): return None - # Gather the number of processes and this process index. - if self.args.parallel_mode == ParallelMode.TPU: - num_processes = xm.xrt_world_size() - process_index = xm.get_ordinal() - elif ( - self.args.parallel_mode == ParallelMode.DISTRIBUTED - or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED - ): - num_processes = dist.get_world_size() - process_index = dist.get_rank() - else: - num_processes = 1 - process_index = 0 - # Build the sampler. if self.args.group_by_length: model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None - if num_processes <= 1: + if self.args.world_size <= 1: return LengthGroupedSampler( self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name ) @@ -516,16 +503,26 @@ class Trainer: return DistributedLengthGroupedSampler( self.train_dataset, self.args.train_batch_size, - num_replicas=num_processes, - rank=process_index, + num_replicas=self.args.world_size, + rank=self.args.process_index, model_input_name=model_input_name, ) else: - if num_processes <= 1: + if self.args.world_size <= 1: return RandomSampler(self.train_dataset) + elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last: + # Use a loop for TPUs when drop_last is False to have all batches have the same size. 
+ return DistributedSamplerWithLoop( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + num_replicas=self.args.world_size, + rank=self.args.process_index, + ) else: - return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index) + return DistributedSampler( + self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index + ) def get_train_dataloader(self) -> DataLoader: """ diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index ae8e249490..673ed13ae8 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int): dist.barrier() +class DistributedSamplerWithLoop(DistributedSampler): + """ + Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the + shuffled samples to make each process have a round multiple of batch_size samples. + + Args: + dataset (:obj:`torch.utils.data.Dataset`): + Dataset used for sampling. + batch_size (:obj:`int`): + The batch size used with this sampler + kwargs: + All other keyword arguments passed to :obj:`DistributedSampler`. + """ + + def __init__(self, dataset, batch_size, **kwargs): + super().__init__(dataset, **kwargs) + self.batch_size = batch_size + + def __iter__(self): + indices = list(super().__iter__()) + remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size + # DistributedSampler already added samples from the beginning to make the number of samples a round multiple + # of the world size, so we skip those. 
+ start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0 + indices += indices[start_remainder : start_remainder + remainder] + return iter(indices) + + class SequentialDistributedSampler(Sampler): """ Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end. @@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler): return self.num_samples -def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset): +def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int): if xm.xrt_world_size() <= 1: return RandomSampler(dataset) return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 85d7fdd402..ea6885ca9e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -742,6 +742,20 @@ class TrainingArguments: return torch.distributed.get_world_size() return 1 + @property + @torch_required + def process_index(self): + """ + The index of the current process used. 
+ """ + if is_torch_tpu_available(): + return xm.get_ordinal() + elif is_sagemaker_distributed_available(): + return sm_dist.get_rank() + elif self.local_rank != -1: + return torch.distributed.get_rank() + return 0 + @property def place_model_on_device(self): """ diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index f56ef140e8..5cd1c39f14 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -27,6 +27,7 @@ if is_torch_available(): from transformers.modeling_outputs import SequenceClassifierOutput from transformers.trainer_pt_utils import ( DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, DistributedTensorGatherer, LabelSmoother, LengthGroupedSampler, @@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase): ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias'] ) # fmt: on + + def test_distributed_sampler_with_loop(self): + batch_size = 16 + for length in [23, 64, 123]: + dataset = list(range(length)) + shard1 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=0) + shard2 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=1) + + # Set seeds + shard1.set_epoch(0) + shard2.set_epoch(0) + + # Sample + samples1 = list(shard1) + samples2 = list(shard2) + + self.assertTrue(len(samples1) % batch_size == 0) + self.assertTrue(len(samples2) % batch_size == 0) + + total = [] + for sample1, sample2 in zip(samples1, samples2): + total += [sample1, sample2] + + self.assertEqual(set(total[:length]), set(dataset)) + self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))