Trainer support for IterableDataset for evaluation and predict (#11286)

* Bulk of the work

* Polish and tests

* Update QA Trainer

* Avoid breaking the predict method

* Deprecation warnings

* Store real eval dataloder

* Get eval dataset reference before wrap
This commit is contained in:
Sylvain Gugger
2021-04-16 16:01:58 -04:00
committed by GitHub
parent e783ea7304
commit d9c62047a8
8 changed files with 608 additions and 169 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
@@ -34,6 +35,7 @@ if is_torch_available():
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
ShardSampler,
get_parameter_names,
)
@@ -283,6 +285,10 @@ class TrainerUtilsTest(unittest.TestCase):
# All shards have the same number of samples
self.assertEqual(len(shard), len(shard_lists[0]))
for shard in shards:
# All shards know the total number of samples
self.assertEqual(shard.num_examples, len(reference))
observed = []
for idx in range(0, len(shard_lists[0]), batch_size):
for shard in shard_lists:
@@ -295,11 +301,62 @@ class TrainerUtilsTest(unittest.TestCase):
reference += reference
self.assertListEqual(observed, reference[: len(observed)])
# Check equivalence between IterableDataset and ShardSampler
dataset.generator.manual_seed(epoch)
reference = list(dataset)
sampler_shards = [
ShardSampler(
reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
)
for i in range(num_processes)
]
for shard, sampler_shard in zip(shard_lists, sampler_shards):
self.assertListEqual(shard, list(sampler_shard))
def test_iterable_dataset_shard(self):
dataset = RandomIterableDataset()
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
shards = [
ShardSampler(
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
)
for i in range(num_processes)
]
shard_lists = [list(shard) for shard in shards]
for shard in shard_lists:
# All shards have a number of samples that is a round multiple of batch size
self.assertTrue(len(shard) % batch_size == 0)
# All shards have the same number of samples
self.assertEqual(len(shard), len(shard_lists[0]))
observed = []
for idx in range(0, len(shard_lists[0]), batch_size):
for shard in shard_lists:
observed += shard[idx : idx + batch_size]
# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
# batch_size
reference = copy.copy(dataset)
if not drop_last:
while len(reference) < len(observed):
reference += reference
self.assertListEqual(observed, reference[: len(observed)])
def test_shard_sampler(self):
for n_elements in [64, 123]:
dataset = list(range(n_elements))
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=2)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=2)
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)