Remove unnecessary columns for all dataset types in Trainer (#17166)
* Remove unneeded columns for IterableDataset * Add test * Update trainer tests * Edit docstring * Lint * Apply feedback * Apply feedback
This commit is contained in:
@@ -109,6 +109,7 @@ from .trainer_utils import (
|
||||
HubStrategy,
|
||||
IntervalStrategy,
|
||||
PredictionOutput,
|
||||
RemoveColumnsCollator,
|
||||
ShardedDDPOption,
|
||||
TrainerMemoryTracker,
|
||||
TrainOutput,
|
||||
@@ -601,27 +602,30 @@ class Trainer:
|
||||
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
|
||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||
if not self.args.remove_unused_columns:
|
||||
return dataset
|
||||
def _set_signature_columns_if_needed(self):
|
||||
if self._signature_columns is None:
|
||||
# Inspect model forward signature to keep only the arguments it accepts.
|
||||
signature = inspect.signature(self.model.forward)
|
||||
self._signature_columns = list(signature.parameters.keys())
|
||||
# Labels may be named label or label_ids, the default data collator handles that.
|
||||
self._signature_columns += ["label", "label_ids"]
|
||||
|
||||
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
|
||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||
if not self.args.remove_unused_columns:
|
||||
return dataset
|
||||
self._set_signature_columns_if_needed()
|
||||
# Labels may be named label or label_ids, the default data collator handles that.
|
||||
signature_columns = self._signature_columns + ["label", "label_ids"]
|
||||
|
||||
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
|
||||
if len(ignored_columns) > 0:
|
||||
dset_description = "" if description is None else f"in the {description} set"
|
||||
logger.info(
|
||||
f"The following columns {dset_description} don't have a corresponding argument in "
|
||||
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
||||
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
|
||||
f" you can safely ignore this message."
|
||||
" you can safely ignore this message."
|
||||
)
|
||||
|
||||
columns = [k for k in self._signature_columns if k in dataset.column_names]
|
||||
columns = [k for k in signature_columns if k in dataset.column_names]
|
||||
|
||||
if version.parse(datasets.__version__) < version.parse("1.4.0"):
|
||||
dataset.set_format(
|
||||
@@ -631,6 +635,24 @@ class Trainer:
|
||||
else:
|
||||
return dataset.remove_columns(ignored_columns)
|
||||
|
||||
def _get_collator_with_removed_columns(
|
||||
self, data_collator: Callable, description: Optional[str] = None
|
||||
) -> Callable:
|
||||
"""Wrap the data collator in a callable removing unused columns."""
|
||||
if not self.args.remove_unused_columns:
|
||||
return data_collator
|
||||
self._set_signature_columns_if_needed()
|
||||
signature_columns = self._signature_columns + self.label_names
|
||||
|
||||
remove_columns_collator = RemoveColumnsCollator(
|
||||
data_collator=data_collator,
|
||||
signature_columns=signature_columns,
|
||||
logger=logger,
|
||||
description=description,
|
||||
model_name=self.model.__class__.__name__,
|
||||
)
|
||||
return remove_columns_collator
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.train_dataset is None or not has_length(self.train_dataset):
|
||||
return None
|
||||
@@ -717,8 +739,11 @@ class Trainer:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||
|
||||
if isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
if self.args.world_size > 1:
|
||||
@@ -733,7 +758,7 @@ class Trainer:
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
collate_fn=data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
@@ -744,7 +769,7 @@ class Trainer:
|
||||
train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=self.data_collator,
|
||||
collate_fn=data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
@@ -794,9 +819,12 @@ class Trainer:
|
||||
if eval_dataset is None and self.eval_dataset is None:
|
||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
data_collator = self.data_collator
|
||||
|
||||
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
|
||||
|
||||
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
||||
if self.args.world_size > 1:
|
||||
@@ -810,7 +838,7 @@ class Trainer:
|
||||
return DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
collate_fn=data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
@@ -821,7 +849,7 @@ class Trainer:
|
||||
eval_dataset,
|
||||
sampler=eval_sampler,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
collate_fn=data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
@@ -838,8 +866,12 @@ class Trainer:
|
||||
The test dataset to use. If it is an `datasets.Dataset`, columns not accepted by the `model.forward()`
|
||||
method are automatically removed. It must implement `__len__`.
|
||||
"""
|
||||
data_collator = self.data_collator
|
||||
|
||||
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
||||
test_dataset = self._remove_unused_columns(test_dataset, description="test")
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
|
||||
|
||||
if isinstance(test_dataset, torch.utils.data.IterableDataset):
|
||||
if self.args.world_size > 1:
|
||||
@@ -853,7 +885,7 @@ class Trainer:
|
||||
return DataLoader(
|
||||
test_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
collate_fn=data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
@@ -865,7 +897,7 @@ class Trainer:
|
||||
test_dataset,
|
||||
sampler=test_sampler,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
collate_fn=data_collator,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ import random
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -655,3 +655,39 @@ class FSDPOption(ExplicitEnum):
|
||||
SHARD_GRAD_OP = "shard_grad_op"
|
||||
OFFLOAD = "offload"
|
||||
AUTO_WRAP = "auto_wrap"
|
||||
|
||||
|
||||
class RemoveColumnsCollator:
|
||||
"""Wrap the data collator to remove unused columns from its output."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_collator,
|
||||
signature_columns,
|
||||
logger=None,
|
||||
model_name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
):
|
||||
self.data_collator = data_collator
|
||||
self.signature_columns = signature_columns
|
||||
self.logger = logger
|
||||
self.description = description
|
||||
self.model_name = model_name
|
||||
self.message_logged = False
|
||||
|
||||
def _remove_columns(self, feature: dict) -> dict:
|
||||
if not self.message_logged and self.logger and self.model_name:
|
||||
ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
|
||||
if len(ignored_columns) > 0:
|
||||
dset_description = "" if self.description is None else f"in the {self.description} set"
|
||||
self.logger.info(
|
||||
f"The following columns {dset_description} don't have a corresponding argument in "
|
||||
f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
||||
f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
|
||||
" you can safely ignore this message."
|
||||
)
|
||||
self.message_logged = True
|
||||
return {k: v for k, v in feature.items() if k in self.signature_columns}
|
||||
|
||||
def __call__(self, features: List[dict]):
|
||||
return self._remove_columns(self.data_collator(features))
|
||||
|
||||
@@ -289,8 +289,7 @@ class TrainingArguments:
|
||||
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
|
||||
set to warn or lower (default), `False` otherwise.
|
||||
remove_unused_columns (`bool`, *optional*, defaults to `True`):
|
||||
If using `datasets.Dataset` datasets, whether or not to automatically remove the columns unused by the
|
||||
model forward method.
|
||||
Whether or not to automatically remove the columns unused by the model forward method.
|
||||
|
||||
(Note that this behavior is not implemented for [`TFTrainer`] yet.)
|
||||
label_names (`List[str]`, *optional*):
|
||||
|
||||
@@ -1329,7 +1329,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_training_iterable_dataset(self):
|
||||
config = RegressionModelConfig()
|
||||
model = RegressionPreTrainedModel(config)
|
||||
train_dataset = SampleIterableDataset()
|
||||
# Adding one column not used by the model should have no impact
|
||||
train_dataset = SampleIterableDataset(label_names=["labels", "extra"])
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
|
||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||
@@ -1363,7 +1364,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_evaluation_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
# Adding one column not used by the model should have no impact
|
||||
eval_dataset = SampleIterableDataset(label_names=["labels", "extra"])
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples")
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||
@@ -1400,7 +1402,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
test_dataset = SampleIterableDataset(length=66)
|
||||
# Adding one column not used by the model should have no impact
|
||||
test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"])
|
||||
preds = trainer.predict(test_dataset).predictions
|
||||
x = test_dataset.dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
@@ -18,8 +18,9 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.data.data_collator import default_data_collator
|
||||
from transformers.testing_utils import require_accelerate, require_torch
|
||||
from transformers.trainer_utils import find_executable_batch_size
|
||||
from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
@@ -457,3 +458,28 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
mock_training_loop_function()
|
||||
self.assertEqual("CUDA out of memory", cm.args[0])
|
||||
|
||||
def test_remove_columns_collator(self):
|
||||
class MockLogger:
|
||||
def __init__(self) -> None:
|
||||
self.called = 0
|
||||
|
||||
def info(self, msg):
|
||||
self.called += 1
|
||||
self.last_msg = msg
|
||||
|
||||
data_batch = [
|
||||
{"col1": 1, "col2": 2, "col3": 3},
|
||||
{"col1": 1, "col2": 2, "col3": 3},
|
||||
]
|
||||
logger = MockLogger()
|
||||
remove_columns_collator = RemoveColumnsCollator(
|
||||
default_data_collator, ["col1", "col2"], logger, "model", "training"
|
||||
)
|
||||
|
||||
self.assertNotIn("col3", remove_columns_collator(data_batch))
|
||||
# check that the logging message is printed out only once
|
||||
remove_columns_collator(data_batch)
|
||||
remove_columns_collator(data_batch)
|
||||
self.assertEqual(logger.called, 1)
|
||||
self.assertIn("col3", logger.last_msg)
|
||||
|
||||
Reference in New Issue
Block a user