Trainer iterable dataset (#11254)

* IterableDatasetShard

* Test and integration in Trainer

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2021-04-14 17:02:26 -04:00
committed by GitHub
parent 83206ca6a8
commit aaaed56ffc
4 changed files with 185 additions and 40 deletions

View File

@@ -44,9 +44,7 @@ if is_torch_available():
from torch.utils.data import IterableDataset
from transformers import (
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
DataCollatorForLanguageModeling,
EarlyStoppingCallback,
GlueDataset,
GlueDataTrainingArguments,
@@ -54,7 +52,6 @@ if is_torch_available():
GPT2LMHeadModel,
LineByLineTextDataset,
PreTrainedModel,
TextDataset,
Trainer,
TrainerState,
)
@@ -138,16 +135,12 @@ class RegressionModelConfig(PretrainedConfig):
if is_torch_available():
class SampleIterableDataset(IterableDataset):
"""
Criteria is not whether it is IterableDataset or not, criteria is whether __len__ is implemented
"""
def __init__(self, file_path, tokenizer):
self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names)
def __iter__(self):
for i in range(len(self.ds)):
yield self.ds[i]
for i in range(len(self.dataset)):
yield self.dataset[i]
class RegressionModel(torch.nn.Module):
def __init__(self, a=0, b=0, double_output=False):
@@ -827,18 +820,12 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(len(dataset), 31)
def test_trainer_iterable_dataset(self):
# Simulate Language Modeling with an IterableDataset, with no __len__ method
# Pick-up a tiny model, so it works on CPU
# See Issue #5990: https://github.com/huggingface/transformers/issues/5990
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
train_dataset = SampleIterableDataset()
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
loader = trainer.get_train_dataloader()
@@ -847,30 +834,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Exception if giving iterable dataset and no max_steps
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
_ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
args1 = RegressionTrainingArguments(output_dir="./examples")
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
# Exception if eval_dataset is iterable in __init__
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
_ = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=train_dataset,
data_collator=data_collator,
)
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
# Exception if predicting with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.predict(train_dataset)
# Exception if evaluating with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.evaluate(train_dataset)
def test_num_train_epochs_in_training(self):

View File

@@ -23,12 +23,14 @@ from transformers.testing_utils import require_torch
if is_torch_available():
import torch
from torch.utils.data import IterableDataset
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
@@ -49,6 +51,22 @@ if is_torch_available():
h = torch.nn.functional.relu(self.linear2(x))
return self.ln2(x + h + self.bias)
class RandomIterableDataset(IterableDataset):
# For testing, an iterable dataset of random length
def __init__(self, p_stop=0.01, max_length=1000):
self.p_stop = p_stop
self.max_length = max_length
self.generator = torch.Generator()
def __iter__(self):
count = 0
stop = False
while not stop and count < self.max_length:
yield count
count += 1
number = torch.rand(1, generator=self.generator).item()
stop = number < self.p_stop
@require_torch
class TrainerUtilsTest(unittest.TestCase):
@@ -243,3 +261,45 @@ class TrainerUtilsTest(unittest.TestCase):
self.assertListEqual(total[:length], dataset)
self.assertListEqual(total[length:], dataset[: (len(total) - length)])
def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_processes=2, epoch=0):
# Set the seed for the base dataset to get the proper reference.
dataset.generator.manual_seed(epoch)
reference = list(dataset)
shards = [
IterableDatasetShard(
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
)
for i in range(num_processes)
]
for shard in shards:
shard.set_epoch(epoch)
shard_lists = [list(shard) for shard in shards]
for shard in shard_lists:
# All shards have a number of samples that is a round multiple of batch size
self.assertTrue(len(shard) % batch_size == 0)
# All shards have the same number of samples
self.assertEqual(len(shard), len(shard_lists[0]))
observed = []
for idx in range(0, len(shard_lists[0]), batch_size):
for shard in shard_lists:
observed += shard[idx : idx + batch_size]
# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
# batch_size
if not drop_last:
while len(reference) < len(observed):
reference += reference
self.assertListEqual(observed, reference[: len(observed)])
def test_iterable_dataset_shard(self):
dataset = RandomIterableDataset()
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)