From edcc66d27ca34f0d7f4c1f18e0c671ab9659555a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 11 May 2022 17:11:26 +0200 Subject: [PATCH] Remove unnecessary columns for all dataset types in `Trainer` (#17166) * Remove unneeded columns for IterableDataset * Add test * Update trainer tests * Edit docstring * Lint * Apply feedback * Apply feedback --- src/transformers/trainer.py | 62 ++++++++++++++++++++++------- src/transformers/trainer_utils.py | 38 +++++++++++++++++- src/transformers/training_args.py | 3 +- tests/trainer/test_trainer.py | 9 +++-- tests/trainer/test_trainer_utils.py | 28 ++++++++++++- 5 files changed, 118 insertions(+), 22 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9855f29a46..ef13590e23 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -109,6 +109,7 @@ from .trainer_utils import ( HubStrategy, IntervalStrategy, PredictionOutput, + RemoveColumnsCollator, ShardedDDPOption, TrainerMemoryTracker, TrainOutput, @@ -601,27 +602,30 @@ class Trainer: if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): model.tie_weights() - def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): - if not self.args.remove_unused_columns: - return dataset + def _set_signature_columns_if_needed(self): if self._signature_columns is None: # Inspect model forward signature to keep only the arguments it accepts. signature = inspect.signature(self.model.forward) self._signature_columns = list(signature.parameters.keys()) - # Labels may be named label or label_ids, the default data collator handles that. - self._signature_columns += ["label", "label_ids"] - ignored_columns = list(set(dataset.column_names) - set(self._signature_columns)) + def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): + if not self.args.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + # Labels may be named label or label_ids, the default data collator handles that. + signature_columns = self._signature_columns + ["label", "label_ids"] + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) if len(ignored_columns) > 0: - dset_description = "" if description is None else f"in the {description} set " + dset_description = "" if description is None else f"in the {description} set" logger.info( f"The following columns {dset_description} don't have a corresponding argument in " f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " - f" you can safely ignore this message." + " you can safely ignore this message." ) - columns = [k for k in self._signature_columns if k in dataset.column_names] + columns = [k for k in signature_columns if k in dataset.column_names] if version.parse(datasets.__version__) < version.parse("1.4.0"): dataset.set_format( @@ -631,6 +635,24 @@ class Trainer: else: return dataset.remove_columns(ignored_columns) + def _get_collator_with_removed_columns( + self, data_collator: Callable, description: Optional[str] = None + ) -> Callable: + """Wrap the data collator in a callable removing unused columns.""" + if not self.args.remove_unused_columns: + return data_collator + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + self.label_names + + remove_columns_collator = RemoveColumnsCollator( + data_collator=data_collator, + signature_columns=signature_columns, + logger=logger, + description=description, + model_name=self.model.__class__.__name__, + ) + return remove_columns_collator + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.train_dataset is None or not has_length(self.train_dataset): return None @@ -717,8 +739,11 @@ class Trainer: raise ValueError("Trainer: training requires a train_dataset.") train_dataset = self.train_dataset + data_collator = self.data_collator if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") if isinstance(train_dataset, torch.utils.data.IterableDataset): if self.args.world_size > 1: @@ -733,7 +758,7 @@ class Trainer: return DataLoader( train_dataset, batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, + collate_fn=data_collator, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.dataloader_pin_memory, ) @@ -744,7 +769,7 @@ class Trainer: train_dataset, batch_size=self._train_batch_size, sampler=train_sampler, - collate_fn=self.data_collator, + collate_fn=data_collator, drop_last=self.args.dataloader_drop_last, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.dataloader_pin_memory, @@ -794,9 +819,12 @@ class Trainer: if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") if isinstance(eval_dataset, torch.utils.data.IterableDataset): if self.args.world_size > 1: @@ -810,7 +838,7 @@ class Trainer: return DataLoader( eval_dataset, batch_size=self.args.eval_batch_size, - collate_fn=self.data_collator, + collate_fn=data_collator, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.dataloader_pin_memory, ) @@ -821,7 +849,7 @@ class Trainer: eval_dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size, - collate_fn=self.data_collator, + collate_fn=data_collator, drop_last=self.args.dataloader_drop_last, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.dataloader_pin_memory, @@ -838,8 +866,12 @@ class Trainer: The test dataset to use. If it is an `datasets.Dataset`, columns not accepted by the `model.forward()` method are automatically removed. It must implement `__len__`. """ + data_collator = self.data_collator + if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): test_dataset = self._remove_unused_columns(test_dataset, description="test") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="test") if isinstance(test_dataset, torch.utils.data.IterableDataset): if self.args.world_size > 1: @@ -853,7 +885,7 @@ class Trainer: return DataLoader( test_dataset, batch_size=self.args.eval_batch_size, - collate_fn=self.data_collator, + collate_fn=data_collator, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.dataloader_pin_memory, ) @@ -865,7 +897,7 @@ class Trainer: test_dataset, sampler=test_sampler, batch_size=self.args.eval_batch_size, - collate_fn=self.data_collator, + collate_fn=data_collator, drop_last=self.args.dataloader_drop_last, pin_memory=self.args.dataloader_pin_memory, ) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index d74d0aed9f..8c76efa65c 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -25,7 +25,7 @@ import random import re import threading import time -from typing import Any, Dict, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import numpy as np @@ -655,3 +655,39 @@ class FSDPOption(ExplicitEnum): SHARD_GRAD_OP = "shard_grad_op" OFFLOAD = "offload" AUTO_WRAP = "auto_wrap" + + +class RemoveColumnsCollator: + """Wrap the data collator to remove unused columns from its output.""" + + def __init__( + self, + data_collator, + signature_columns, + logger=None, + model_name: Optional[str] = None, + description: Optional[str] = None, + ): + self.data_collator = data_collator + self.signature_columns = signature_columns + self.logger = logger + self.description = description + self.model_name = model_name + self.message_logged = False + + def _remove_columns(self, feature: dict) -> dict: + if not self.message_logged and self.logger and self.model_name: + ignored_columns = list(set(feature.keys()) - set(self.signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if self.description is None else f"in the {self.description} set" + self.logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, " + " you can safely ignore this message." + ) + self.message_logged = True + return {k: v for k, v in feature.items() if k in self.signature_columns} + + def __call__(self, features: List[dict]): + return self._remove_columns(self.data_collator(features)) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index cb929ab631..f8b15ebc85 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -289,8 +289,7 @@ class TrainingArguments: [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is set to warn or lower (default), `False` otherwise. remove_unused_columns (`bool`, *optional*, defaults to `True`): - If using `datasets.Dataset` datasets, whether or not to automatically remove the columns unused by the - model forward method. + Whether or not to automatically remove the columns unused by the model forward method. (Note that this behavior is not implemented for [`TFTrainer`] yet.) label_names (`List[str]`, *optional*): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e5e11fcd21..186d141bc3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1329,7 +1329,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): def test_training_iterable_dataset(self): config = RegressionModelConfig() model = RegressionPreTrainedModel(config) - train_dataset = SampleIterableDataset() + # Adding one column not used by the model should have no impact + train_dataset = SampleIterableDataset(label_names=["labels", "extra"]) args = RegressionTrainingArguments(output_dir="./examples", max_steps=4) trainer = Trainer(model=model, args=args, train_dataset=train_dataset) @@ -1363,7 +1364,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): def test_evaluation_iterable_dataset(self): config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) - eval_dataset = SampleIterableDataset() + # Adding one column not used by the model should have no impact + eval_dataset = SampleIterableDataset(label_names=["labels", "extra"]) args = RegressionTrainingArguments(output_dir="./examples") trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy()) @@ -1400,7 +1402,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(np.allclose(preds, 1.5 * x + 2.5)) # With a number of elements not a round multiple of the batch size - test_dataset = SampleIterableDataset(length=66) + # Adding one column not used by the model should have no impact + test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"]) preds = trainer.predict(test_dataset).predictions x = test_dataset.dataset.x self.assertTrue(np.allclose(preds, 1.5 * x + 2.5)) diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py index 41448fdcb4..168beb95b9 100644 --- a/tests/trainer/test_trainer_utils.py +++ b/tests/trainer/test_trainer_utils.py @@ -18,8 +18,9 @@ import unittest import numpy as np +from transformers.data.data_collator import default_data_collator from transformers.testing_utils import require_accelerate, require_torch -from transformers.trainer_utils import find_executable_batch_size +from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size from transformers.utils import is_torch_available @@ -457,3 +458,28 @@ class TrainerUtilsTest(unittest.TestCase): with self.assertRaises(RuntimeError) as cm: mock_training_loop_function() self.assertEqual("CUDA out of memory", cm.args[0]) + + def test_remove_columns_collator(self): + class MockLogger: + def __init__(self) -> None: + self.called = 0 + + def info(self, msg): + self.called += 1 + self.last_msg = msg + + data_batch = [ + {"col1": 1, "col2": 2, "col3": 3}, + {"col1": 1, "col2": 2, "col3": 3}, + ] + logger = MockLogger() + remove_columns_collator = RemoveColumnsCollator( + default_data_collator, ["col1", "col2"], logger, "model", "training" + ) + + self.assertNotIn("col3", remove_columns_collator(data_batch)) + # check that the logging message is printed out only once + remove_columns_collator(data_batch) + remove_columns_collator(data_batch) + self.assertEqual(logger.called, 1) + self.assertIn("col3", logger.last_msg)