Remove unnecessary columns for all dataset types in Trainer (#17166)

* Remove unneeded columns for IterableDataset

* Add test

* Update trainer tests

* Edit docstring

* Lint

* Apply feedback

* Apply feedback
This commit is contained in:
Antoni Baum
2022-05-11 17:11:26 +02:00
committed by GitHub
parent c33f6046c3
commit edcc66d27c
5 changed files with 118 additions and 22 deletions

View File

@@ -18,8 +18,9 @@ import unittest
import numpy as np
from transformers.data.data_collator import default_data_collator
from transformers.testing_utils import require_accelerate, require_torch
from transformers.trainer_utils import find_executable_batch_size
from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
from transformers.utils import is_torch_available
@@ -457,3 +458,28 @@ class TrainerUtilsTest(unittest.TestCase):
with self.assertRaises(RuntimeError) as cm:
mock_training_loop_function()
self.assertEqual("CUDA out of memory", cm.args[0])
def test_remove_columns_collator(self):
class MockLogger:
def __init__(self) -> None:
self.called = 0
def info(self, msg):
self.called += 1
self.last_msg = msg
data_batch = [
{"col1": 1, "col2": 2, "col3": 3},
{"col1": 1, "col2": 2, "col3": 3},
]
logger = MockLogger()
remove_columns_collator = RemoveColumnsCollator(
default_data_collator, ["col1", "col2"], logger, "model", "training"
)
self.assertNotIn("col3", remove_columns_collator(data_batch))
# check that the logging message is printed out only once
remove_columns_collator(data_batch)
remove_columns_collator(data_batch)
self.assertEqual(logger.called, 1)
self.assertIn("col3", logger.last_msg)