Trainer support for IterableDataset for evaluation and predict (#11286)

* Bulk of the work

* Polish and tests

* Update QA Trainer

* Avoid breaking the predict method

* Deprecation warnings

* Store real eval dataloder

* Get eval dataset reference before wrap
This commit is contained in:
Sylvain Gugger
2021-04-16 16:01:58 -04:00
committed by GitHub
parent e783ea7304
commit d9c62047a8
8 changed files with 608 additions and 169 deletions

View File

@@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
)
self.assertEqual(len(dataset), 31)
def test_trainer_iterable_dataset(self):
def test_training_iterable_dataset(self):
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
train_dataset = SampleIterableDataset()
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
self.assertEqual(trainer.state.global_step, 4)
loader = trainer.get_train_dataloader()
self.assertIsInstance(loader, torch.utils.data.DataLoader)
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
# Exception if giving iterable dataset and no max_steps
with self.assertRaises(ValueError):
args1 = RegressionTrainingArguments(output_dir="./examples")
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
def test_evaluation_iterable_dataset(self):
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Exception if eval_dataset is iterable in __init__
with self.assertRaises(ValueError):
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
results = trainer.evaluate()
# Exception if predicting with iterable dataset
with self.assertRaises(ValueError):
trainer.predict(train_dataset)
x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
# Exception if evaluating with iterable dataset
with self.assertRaises(ValueError):
trainer.evaluate(train_dataset)
# With a number of elements not a round multiple of the batch size
eval_dataset = SampleIterableDataset(length=66)
results = trainer.evaluate(eval_dataset)
x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
def test_predict_iterable_dataset(self):
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
preds = trainer.predict(trainer.eval_dataset).predictions
x = eval_dataset.dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
test_dataset = SampleIterableDataset(length=66)
preds = trainer.predict(test_dataset).predictions
x = test_dataset.dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
def test_num_train_epochs_in_training(self):
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.