Smmp batch not divisible by microbatches fix (#10778)

* Added debug prints

* Added config

* Added prints

* Added prints

* Added extra samples to SequentialDistributedSampler

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

* Added deubg prints

* Removed extra prints

* Making predicitons and labels multiple of batchsize

* updated number of microbatches

* Removed extra prints

* Made start_remainder similar to DistributedSamplerWithLoop

* Minor spacing update

* Added debug prints

Added config

Added prints

Added prints

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

Added extra samples to SequentialDistributedSampler

Added deubg prints

Removed extra prints

Making predicitons and labels multiple of batchsize

updated number of microbatches

Removed extra prints

Squashing redundant commits

* Made start_remainder similar to DistributedSamplerWithLoop

Minor spacing update

Made start_remainder similar to DistributedSamplerWithLoop

* Test and styling

* Rename test

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
Mansi Mane
2021-03-17 16:18:11 -07:00
committed by GitHub
parent 40b049c701
commit 0282e24eef
4 changed files with 49 additions and 5 deletions

View File

@@ -31,6 +31,7 @@ if is_torch_available():
DistributedTensorGatherer,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
get_parameter_names,
)
@@ -167,3 +168,35 @@ class TrainerUtilsTest(unittest.TestCase):
self.assertEqual(set(total[:length]), set(dataset))
self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
def test_sequential_distributed_sampler(self):
batch_size = 16
for length in [23, 64, 123]:
dataset = list(range(length))
shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0)
shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1)
# Sample
samples1 = list(shard1)
samples2 = list(shard2)
total = samples1 + samples2
self.assertListEqual(total[:length], dataset)
self.assertListEqual(total[length:], dataset[: (len(total) - length)])
# With a batch_size passed
shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0, batch_size=batch_size)
shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1, batch_size=batch_size)
# Sample
samples1 = list(shard1)
samples2 = list(shard2)
self.assertTrue(len(samples1) % batch_size == 0)
self.assertTrue(len(samples2) % batch_size == 0)
total = samples1 + samples2
self.assertListEqual(total[:length], dataset)
self.assertListEqual(total[length:], dataset[: (len(total) - length)])