Trainer support for IterableDataset for evaluation and predict (#11286)
* Bulk of the work * Polish and tests * Update QA Trainer * Avoid breaking the predict method * Deprecation warnings * Store real eval dataloder * Get eval dataset reference before wrap
This commit is contained in:
@@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
self.assertEqual(len(dataset), 31)
|
||||
|
||||
def test_trainer_iterable_dataset(self):
|
||||
def test_training_iterable_dataset(self):
|
||||
config = RegressionModelConfig()
|
||||
model = RegressionPreTrainedModel(config)
|
||||
train_dataset = SampleIterableDataset()
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
|
||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
self.assertEqual(trainer.state.global_step, 4)
|
||||
|
||||
loader = trainer.get_train_dataloader()
|
||||
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
||||
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
||||
|
||||
# Exception if giving iterable dataset and no max_steps
|
||||
with self.assertRaises(ValueError):
|
||||
args1 = RegressionTrainingArguments(output_dir="./examples")
|
||||
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
|
||||
def test_evaluation_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Exception if eval_dataset is iterable in __init__
|
||||
with self.assertRaises(ValueError):
|
||||
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
|
||||
args = RegressionTrainingArguments(output_dir="./examples")
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||
results = trainer.evaluate()
|
||||
|
||||
# Exception if predicting with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.predict(train_dataset)
|
||||
x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
# Exception if evaluating with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.evaluate(train_dataset)
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
eval_dataset = SampleIterableDataset(length=66)
|
||||
results = trainer.evaluate(eval_dataset)
|
||||
|
||||
x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
def test_predict_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples")
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
x = eval_dataset.dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
test_dataset = SampleIterableDataset(length=66)
|
||||
preds = trainer.predict(test_dataset).predictions
|
||||
x = test_dataset.dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
def test_num_train_epochs_in_training(self):
|
||||
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
|
||||
|
||||
Reference in New Issue
Block a user