Remove unnecessary columns for all dataset types in Trainer (#17166)
* Remove unneeded columns for IterableDataset * Add test * Update trainer tests * Edit docstring * Lint * Apply feedback * Apply feedback
This commit is contained in:
@@ -1329,7 +1329,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_training_iterable_dataset(self):
|
||||
config = RegressionModelConfig()
|
||||
model = RegressionPreTrainedModel(config)
|
||||
train_dataset = SampleIterableDataset()
|
||||
# Adding one column not used by the model should have no impact
|
||||
train_dataset = SampleIterableDataset(label_names=["labels", "extra"])
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
|
||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||
@@ -1363,7 +1364,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_evaluation_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
# Adding one column not used by the model should have no impact
|
||||
eval_dataset = SampleIterableDataset(label_names=["labels", "extra"])
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples")
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||
@@ -1400,7 +1402,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
test_dataset = SampleIterableDataset(length=66)
|
||||
# Adding one column not used by the model should have no impact
|
||||
test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"])
|
||||
preds = trainer.predict(test_dataset).predictions
|
||||
x = test_dataset.dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
@@ -18,8 +18,9 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.data.data_collator import default_data_collator
|
||||
from transformers.testing_utils import require_accelerate, require_torch
|
||||
from transformers.trainer_utils import find_executable_batch_size
|
||||
from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
@@ -457,3 +458,28 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
mock_training_loop_function()
|
||||
self.assertEqual("CUDA out of memory", cm.args[0])
|
||||
|
||||
def test_remove_columns_collator(self):
|
||||
class MockLogger:
|
||||
def __init__(self) -> None:
|
||||
self.called = 0
|
||||
|
||||
def info(self, msg):
|
||||
self.called += 1
|
||||
self.last_msg = msg
|
||||
|
||||
data_batch = [
|
||||
{"col1": 1, "col2": 2, "col3": 3},
|
||||
{"col1": 1, "col2": 2, "col3": 3},
|
||||
]
|
||||
logger = MockLogger()
|
||||
remove_columns_collator = RemoveColumnsCollator(
|
||||
default_data_collator, ["col1", "col2"], logger, "model", "training"
|
||||
)
|
||||
|
||||
self.assertNotIn("col3", remove_columns_collator(data_batch))
|
||||
# check that the logging message is printed out only once
|
||||
remove_columns_collator(data_batch)
|
||||
remove_columns_collator(data_batch)
|
||||
self.assertEqual(logger.called, 1)
|
||||
self.assertIn("col3", logger.last_msg)
|
||||
|
||||
Reference in New Issue
Block a user