From 0282e24eef9ce542e4f783121f68c377fd4e3984 Mon Sep 17 00:00:00 2001 From: Mansi Mane Date: Wed, 17 Mar 2021 16:18:11 -0700 Subject: [PATCH] Smmp batch not divisible by microbatches fix (#10778) * Added debug prints * Added config * Added prints * Added prints * Added extra samples to SequentialDistributedSampler * Added extra samples to SequentialDistributedSampler Updated SequentialDistributedSampler call * Added debug prints * Removed extra prints * Making predictions and labels multiple of batchsize * updated number of microbatches * Removed extra prints * Made start_remainder similar to DistributedSamplerWithLoop * Minor spacing update * Added debug prints Added config Added prints Added prints * Added extra samples to SequentialDistributedSampler Updated SequentialDistributedSampler call Added extra samples to SequentialDistributedSampler Added debug prints Removed extra prints Making predictions and labels multiple of batchsize updated number of microbatches Removed extra prints Squashing redundant commits * Made start_remainder similar to DistributedSamplerWithLoop Minor spacing update Made start_remainder similar to DistributedSamplerWithLoop * Test and styling * Rename test Co-authored-by: Sylvain Gugger --- src/transformers/sagemaker/trainer_sm.py | 7 ++++- src/transformers/trainer.py | 4 +-- src/transformers/trainer_pt_utils.py | 10 +++++-- tests/test_trainer_utils.py | 33 ++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/src/transformers/sagemaker/trainer_sm.py b/src/transformers/sagemaker/trainer_sm.py index 0d828b25aa..95ee4cab61 100644 --- a/src/transformers/sagemaker/trainer_sm.py +++ b/src/transformers/sagemaker/trainer_sm.py @@ -112,7 +112,12 @@ class SageMakerTrainer(Trainer): def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]: if self.is_model_parallel_enabled: - return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank()) 
+ return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) else: return super()._get_eval_sampler(eval_dataset) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index bf1a5e1731..a809cb7fa1 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1812,8 +1812,8 @@ class Trainer: eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) if not prediction_loss_only: - preds_gatherer = DistributedTensorGatherer(world_size, num_examples) - labels_gatherer = DistributedTensorGatherer(world_size, num_examples) + preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) model.eval() diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 673ed13ae8..fb0ca59531 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -220,7 +220,7 @@ class SequentialDistributedSampler(Sampler): or `reduce` resulting tensors at the end of the loop. 
""" - def __init__(self, dataset, num_replicas=None, rank=None): + def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -232,8 +232,14 @@ class SequentialDistributedSampler(Sampler): self.dataset = dataset self.num_replicas = num_replicas self.rank = rank - self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + num_samples = len(self.dataset) + # Add extra samples to make num_samples a multiple of batch_size if passed + if batch_size is not None: + self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size + else: + self.num_samples = int(math.ceil(num_samples / num_replicas)) self.total_size = self.num_samples * self.num_replicas + self.batch_size = batch_size def __iter__(self): indices = list(range(len(self.dataset))) diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index 5cd1c39f14..5d0672794b 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -31,6 +31,7 @@ if is_torch_available(): DistributedTensorGatherer, LabelSmoother, LengthGroupedSampler, + SequentialDistributedSampler, get_parameter_names, ) @@ -167,3 +168,35 @@ class TrainerUtilsTest(unittest.TestCase): self.assertEqual(set(total[:length]), set(dataset)) self.assertEqual(set(total[length:]), set(total[: (len(total) - length)])) + + def test_sequential_distributed_sampler(self): + batch_size = 16 + for length in [23, 64, 123]: + dataset = list(range(length)) + shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0) + shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1) + + # Sample + samples1 = list(shard1) + samples2 = list(shard2) + + total = samples1 + samples2 + + self.assertListEqual(total[:length], dataset) + self.assertListEqual(total[length:], dataset[: (len(total) - length)]) + + # With a batch_size passed + shard1 
= SequentialDistributedSampler(dataset, num_replicas=2, rank=0, batch_size=batch_size) + shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1, batch_size=batch_size) + + # Sample + samples1 = list(shard1) + samples2 = list(shard2) + + self.assertTrue(len(samples1) % batch_size == 0) + self.assertTrue(len(samples2) % batch_size == 0) + + total = samples1 + samples2 + + self.assertListEqual(total[:length], dataset) + self.assertListEqual(total[length:], dataset[: (len(total) - length)])