Trainer iterable dataset (#11254)
* IterableDatasetShard * Test and integration in Trainer * Update src/transformers/trainer_pt_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Style Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -23,12 +23,14 @@ from transformers.testing_utils import require_torch
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
from transformers.trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
@@ -49,6 +51,22 @@ if is_torch_available():
|
||||
h = torch.nn.functional.relu(self.linear2(x))
|
||||
return self.ln2(x + h + self.bias)
|
||||
|
||||
class RandomIterableDataset(IterableDataset):
|
||||
# For testing, an iterable dataset of random length
|
||||
def __init__(self, p_stop=0.01, max_length=1000):
|
||||
self.p_stop = p_stop
|
||||
self.max_length = max_length
|
||||
self.generator = torch.Generator()
|
||||
|
||||
def __iter__(self):
|
||||
count = 0
|
||||
stop = False
|
||||
while not stop and count < self.max_length:
|
||||
yield count
|
||||
count += 1
|
||||
number = torch.rand(1, generator=self.generator).item()
|
||||
stop = number < self.p_stop
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerUtilsTest(unittest.TestCase):
|
||||
@@ -243,3 +261,45 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(total[:length], dataset)
|
||||
self.assertListEqual(total[length:], dataset[: (len(total) - length)])
|
||||
|
||||
def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_processes=2, epoch=0):
|
||||
# Set the seed for the base dataset to get the proper reference.
|
||||
dataset.generator.manual_seed(epoch)
|
||||
reference = list(dataset)
|
||||
|
||||
shards = [
|
||||
IterableDatasetShard(
|
||||
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
|
||||
)
|
||||
for i in range(num_processes)
|
||||
]
|
||||
for shard in shards:
|
||||
shard.set_epoch(epoch)
|
||||
shard_lists = [list(shard) for shard in shards]
|
||||
|
||||
for shard in shard_lists:
|
||||
# All shards have a number of samples that is a round multiple of batch size
|
||||
self.assertTrue(len(shard) % batch_size == 0)
|
||||
# All shards have the same number of samples
|
||||
self.assertEqual(len(shard), len(shard_lists[0]))
|
||||
|
||||
observed = []
|
||||
for idx in range(0, len(shard_lists[0]), batch_size):
|
||||
for shard in shard_lists:
|
||||
observed += shard[idx : idx + batch_size]
|
||||
|
||||
# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
|
||||
# batch_size
|
||||
if not drop_last:
|
||||
while len(reference) < len(observed):
|
||||
reference += reference
|
||||
self.assertListEqual(observed, reference[: len(observed)])
|
||||
|
||||
def test_iterable_dataset_shard(self):
|
||||
dataset = RandomIterableDataset()
|
||||
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
|
||||
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
|
||||
|
||||
Reference in New Issue
Block a user