Remove unnecessary columns for all dataset types in Trainer (#17166)

* Remove unneeded columns for IterableDataset

* Add test

* Update trainer tests

* Edit docstring

* Lint

* Apply feedback

* Apply feedback
This commit is contained in:
Antoni Baum
2022-05-11 17:11:26 +02:00
committed by GitHub
parent c33f6046c3
commit edcc66d27c
5 changed files with 118 additions and 22 deletions

View File

@@ -1329,7 +1329,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
def test_training_iterable_dataset(self):
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
train_dataset = SampleIterableDataset()
# Adding one column not used by the model should have no impact
train_dataset = SampleIterableDataset(label_names=["labels", "extra"])
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
@@ -1363,7 +1364,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
def test_evaluation_iterable_dataset(self):
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Adding one column not used by the model should have no impact
eval_dataset = SampleIterableDataset(label_names=["labels", "extra"])
args = RegressionTrainingArguments(output_dir="./examples")
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
@@ -1400,7 +1402,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
test_dataset = SampleIterableDataset(length=66)
# Adding one column not used by the model should have no impact
test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"])
preds = trainer.predict(test_dataset).predictions
x = test_dataset.dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))