Trainer iterable dataset (#11254)
* IterableDatasetShard * Test and integration in Trainer * Update src/transformers/trainer_pt_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Style Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -81,6 +81,7 @@ from .trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
@@ -493,9 +494,7 @@ class Trainer:
|
||||
dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
|
||||
self.train_dataset, collections.abc.Sized
|
||||
):
|
||||
if not isinstance(self.train_dataset, collections.abc.Sized):
|
||||
return None
|
||||
|
||||
# Build the sampler.
|
||||
@@ -553,6 +552,26 @@ class Trainer:
|
||||
"""
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
if isinstance(self.train_dataset, torch.utils.data.dataset.IterableDataset):
|
||||
if self.args.world_size > 1:
|
||||
train_dataset = IterableDatasetShard(
|
||||
self.train_dataset,
|
||||
batch_size=self.args.train_batch_size,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_processes=self.args.world_size,
|
||||
process_index=self.args.process_index,
|
||||
)
|
||||
else:
|
||||
train_dataset = self.train_dataset
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self.args.train_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
|
||||
train_sampler = self._get_train_sampler()
|
||||
|
||||
return DataLoader(
|
||||
|
||||
@@ -28,7 +28,7 @@ from typing import Dict, Iterator, List, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.dataset import Dataset, IterableDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
@@ -576,6 +576,96 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
return iter(indices)
|
||||
|
||||
|
||||
class IterableDatasetShard(IterableDataset):
|
||||
"""
|
||||
Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
|
||||
will always yield a number of samples that is a round multiple of the actual batch size (which is :obj:`batch_size
|
||||
x num_processes`). Depending on the value of the :obj:`drop_last` attribute, it will either stop the iteration at
|
||||
the first batch that would be too small or loop with indices from the beginning.
|
||||
|
||||
On two processes with an iterable dataset yielding of :obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch
|
||||
size of 2:
|
||||
|
||||
- the shard on process 0 will yield :obj:`[0, 1, 4, 5, 8, 9]` so will see batches :obj:`[0, 1]`, :obj:`[4, 5]`,
|
||||
:obj:`[8, 9]`
|
||||
- the shard on process 1 will yield :obj:`[2, 3, 6, 7, 10, 11]` so will see batches :obj:`[2, 3]`, :obj:`[6, 7]`,
|
||||
:obj:`[10, 11]`
|
||||
|
||||
.. warning:
|
||||
|
||||
If your IterableDataset implements some randomization that needs to be applied the same way on all processes
|
||||
(for instance, a shuffling), you should use a :obj:`torch.Generator` in a :obj:`generator` attribute of the
|
||||
:obj:`dataset` to generate your random numbers and call the
|
||||
:meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
|
||||
seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
|
||||
Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom
|
||||
logic.
|
||||
|
||||
Args:
|
||||
dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
|
||||
The batch sampler to split in several shards.
|
||||
batch_size (:obj:`int`, `optional`, defaults to 1):
|
||||
The size of the batches per shard.
|
||||
drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
|
||||
beginning.
|
||||
num_processes (:obj:`int`, `optional`, defaults to 1):
|
||||
The number of processes running concurrently.
|
||||
process_index (:obj:`int`, `optional`, defaults to 0):
|
||||
The index of the current process.
|
||||
seed (:obj:`int`, `optional`, defaults to 0):
|
||||
A random seed that will be used for the random number generation in
|
||||
:meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: IterableDataset,
|
||||
batch_size: int = 1,
|
||||
drop_last: bool = False,
|
||||
num_processes: int = 1,
|
||||
process_index: int = 0,
|
||||
seed: int = 0,
|
||||
):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.drop_last = drop_last
|
||||
self.num_processes = num_processes
|
||||
self.process_index = process_index
|
||||
self.seed = seed
|
||||
self.epoch = 0
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
|
||||
self.dataset.generator.manual_seed(self.seed + self.epoch)
|
||||
real_batch_size = self.batch_size * self.num_processes
|
||||
process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
|
||||
|
||||
first_batch = None
|
||||
current_batch = []
|
||||
for element in self.dataset:
|
||||
current_batch.append(element)
|
||||
# Wait to have a full batch before yielding elements.
|
||||
if len(current_batch) == real_batch_size:
|
||||
for i in process_slice:
|
||||
yield current_batch[i]
|
||||
if first_batch is None:
|
||||
first_batch = current_batch.copy()
|
||||
current_batch = []
|
||||
|
||||
# Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
|
||||
if not self.drop_last and len(current_batch) > 0:
|
||||
if first_batch is None:
|
||||
first_batch = current_batch.copy()
|
||||
while len(current_batch) < real_batch_size:
|
||||
current_batch += first_batch
|
||||
for i in process_slice:
|
||||
yield current_batch[i]
|
||||
|
||||
|
||||
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
|
||||
# helper methods here
|
||||
|
||||
|
||||
@@ -44,9 +44,7 @@ if is_torch_available():
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from transformers import (
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollatorForLanguageModeling,
|
||||
EarlyStoppingCallback,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
@@ -54,7 +52,6 @@ if is_torch_available():
|
||||
GPT2LMHeadModel,
|
||||
LineByLineTextDataset,
|
||||
PreTrainedModel,
|
||||
TextDataset,
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
@@ -138,16 +135,12 @@ class RegressionModelConfig(PretrainedConfig):
|
||||
if is_torch_available():
|
||||
|
||||
class SampleIterableDataset(IterableDataset):
|
||||
"""
|
||||
Criteria is not whether it is IterableDataset or not, criteria is whether __len__ is implemented
|
||||
"""
|
||||
|
||||
def __init__(self, file_path, tokenizer):
|
||||
self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
|
||||
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
|
||||
self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.ds)):
|
||||
yield self.ds[i]
|
||||
for i in range(len(self.dataset)):
|
||||
yield self.dataset[i]
|
||||
|
||||
class RegressionModel(torch.nn.Module):
|
||||
def __init__(self, a=0, b=0, double_output=False):
|
||||
@@ -827,18 +820,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(len(dataset), 31)
|
||||
|
||||
def test_trainer_iterable_dataset(self):
|
||||
# Simulate Language Modeling with an IterableDataset, with no __len__ method
|
||||
# Pick-up a tiny model, so it works on CPU
|
||||
# See Issue #5990: https://github.com/huggingface/transformers/issues/5990
|
||||
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
|
||||
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
|
||||
config = RegressionModelConfig()
|
||||
model = RegressionPreTrainedModel(config)
|
||||
train_dataset = SampleIterableDataset()
|
||||
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
|
||||
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
|
||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
|
||||
loader = trainer.get_train_dataloader()
|
||||
@@ -847,30 +834,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
# Exception if giving iterable dataset and no max_steps
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
_ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
|
||||
args1 = RegressionTrainingArguments(output_dir="./examples")
|
||||
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
|
||||
|
||||
# Exception if eval_dataset is iterable in __init__
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
|
||||
_ = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=train_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
|
||||
|
||||
# Exception if predicting with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
|
||||
trainer.predict(train_dataset)
|
||||
|
||||
# Exception if evaluating with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
|
||||
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
|
||||
trainer.evaluate(train_dataset)
|
||||
|
||||
def test_num_train_epochs_in_training(self):
|
||||
|
||||
@@ -23,12 +23,14 @@ from transformers.testing_utils import require_torch
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
from transformers.trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
@@ -49,6 +51,22 @@ if is_torch_available():
|
||||
h = torch.nn.functional.relu(self.linear2(x))
|
||||
return self.ln2(x + h + self.bias)
|
||||
|
||||
class RandomIterableDataset(IterableDataset):
|
||||
# For testing, an iterable dataset of random length
|
||||
def __init__(self, p_stop=0.01, max_length=1000):
|
||||
self.p_stop = p_stop
|
||||
self.max_length = max_length
|
||||
self.generator = torch.Generator()
|
||||
|
||||
def __iter__(self):
|
||||
count = 0
|
||||
stop = False
|
||||
while not stop and count < self.max_length:
|
||||
yield count
|
||||
count += 1
|
||||
number = torch.rand(1, generator=self.generator).item()
|
||||
stop = number < self.p_stop
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerUtilsTest(unittest.TestCase):
|
||||
@@ -243,3 +261,45 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(total[:length], dataset)
|
||||
self.assertListEqual(total[length:], dataset[: (len(total) - length)])
|
||||
|
||||
def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_processes=2, epoch=0):
|
||||
# Set the seed for the base dataset to get the proper reference.
|
||||
dataset.generator.manual_seed(epoch)
|
||||
reference = list(dataset)
|
||||
|
||||
shards = [
|
||||
IterableDatasetShard(
|
||||
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
|
||||
)
|
||||
for i in range(num_processes)
|
||||
]
|
||||
for shard in shards:
|
||||
shard.set_epoch(epoch)
|
||||
shard_lists = [list(shard) for shard in shards]
|
||||
|
||||
for shard in shard_lists:
|
||||
# All shards have a number of samples that is a round multiple of batch size
|
||||
self.assertTrue(len(shard) % batch_size == 0)
|
||||
# All shards have the same number of samples
|
||||
self.assertEqual(len(shard), len(shard_lists[0]))
|
||||
|
||||
observed = []
|
||||
for idx in range(0, len(shard_lists[0]), batch_size):
|
||||
for shard in shard_lists:
|
||||
observed += shard[idx : idx + batch_size]
|
||||
|
||||
# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
|
||||
# batch_size
|
||||
if not drop_last:
|
||||
while len(reference) < len(observed):
|
||||
reference += reference
|
||||
self.assertListEqual(observed, reference[: len(observed)])
|
||||
|
||||
def test_iterable_dataset_shard(self):
|
||||
dataset = RandomIterableDataset()
|
||||
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
|
||||
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
|
||||
|
||||
Reference in New Issue
Block a user