Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler (#13820)

* Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler

* Fix
This commit is contained in:
Zhaofeng Wu
2021-10-05 06:04:39 -07:00
committed by GitHub
parent d4e4efce68
commit 1b74af76b7
3 changed files with 35 additions and 30 deletions

View File

@@ -181,7 +181,7 @@ class TrainerUtilsTest(unittest.TestCase):
# Put one bigger than the others to check it ends up in first position
lengths[32] = 50
indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths))
indices = list(LengthGroupedSampler(4, lengths=lengths))
# The biggest element should be first
self.assertEqual(lengths[indices[0]], 50)
# The indices should be a permutation of range(100)
@@ -196,7 +196,7 @@ class TrainerUtilsTest(unittest.TestCase):
# Put one bigger than the others to check it ends up in first position
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
indices = list(LengthGroupedSampler(data, 4))
indices = list(LengthGroupedSampler(4, dataset=data))
# The biggest element should be first
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
# The indices should be a permutation of range(6)
@@ -211,7 +211,7 @@ class TrainerUtilsTest(unittest.TestCase):
# Put one bigger than the others to check it ends up in first position
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
indices = list(LengthGroupedSampler(data, 4))
indices = list(LengthGroupedSampler(4, dataset=data))
# The biggest element should be first
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
# The indices should be a permutation of range(6)
@@ -223,8 +223,8 @@ class TrainerUtilsTest(unittest.TestCase):
# Put one bigger than the others to check it ends up in first position
lengths[32] = 50
indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths))
indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths))
indices_process_0 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=0, lengths=lengths))
indices_process_1 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=1, lengths=lengths))
# The biggest element should be first
self.assertEqual(lengths[indices_process_0[0]], 50)
# The indices should be a permutation of range(100)