Remove unnecessary columns for all dataset types in Trainer (#17166)
* Remove unneeded columns for IterableDataset * Add test * Update trainer tests * Edit docstring * Lint * Apply feedback * Apply feedback
This commit is contained in:
@@ -109,6 +109,7 @@ from .trainer_utils import (
|
|||||||
HubStrategy,
|
HubStrategy,
|
||||||
IntervalStrategy,
|
IntervalStrategy,
|
||||||
PredictionOutput,
|
PredictionOutput,
|
||||||
|
RemoveColumnsCollator,
|
||||||
ShardedDDPOption,
|
ShardedDDPOption,
|
||||||
TrainerMemoryTracker,
|
TrainerMemoryTracker,
|
||||||
TrainOutput,
|
TrainOutput,
|
||||||
@@ -601,27 +602,30 @@ class Trainer:
|
|||||||
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
|
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
def _set_signature_columns_if_needed(self):
|
||||||
if not self.args.remove_unused_columns:
|
|
||||||
return dataset
|
|
||||||
if self._signature_columns is None:
|
if self._signature_columns is None:
|
||||||
# Inspect model forward signature to keep only the arguments it accepts.
|
# Inspect model forward signature to keep only the arguments it accepts.
|
||||||
signature = inspect.signature(self.model.forward)
|
signature = inspect.signature(self.model.forward)
|
||||||
self._signature_columns = list(signature.parameters.keys())
|
self._signature_columns = list(signature.parameters.keys())
|
||||||
# Labels may be named label or label_ids, the default data collator handles that.
|
|
||||||
self._signature_columns += ["label", "label_ids"]
|
|
||||||
|
|
||||||
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
|
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||||
|
if not self.args.remove_unused_columns:
|
||||||
|
return dataset
|
||||||
|
self._set_signature_columns_if_needed()
|
||||||
|
# Labels may be named label or label_ids, the default data collator handles that.
|
||||||
|
signature_columns = self._signature_columns + ["label", "label_ids"]
|
||||||
|
|
||||||
|
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
|
||||||
if len(ignored_columns) > 0:
|
if len(ignored_columns) > 0:
|
||||||
dset_description = "" if description is None else f"in the {description} set"
|
dset_description = "" if description is None else f"in the {description} set"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"The following columns {dset_description} don't have a corresponding argument in "
|
f"The following columns {dset_description} don't have a corresponding argument in "
|
||||||
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
||||||
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
|
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
|
||||||
f" you can safely ignore this message."
|
" you can safely ignore this message."
|
||||||
)
|
)
|
||||||
|
|
||||||
columns = [k for k in self._signature_columns if k in dataset.column_names]
|
columns = [k for k in signature_columns if k in dataset.column_names]
|
||||||
|
|
||||||
if version.parse(datasets.__version__) < version.parse("1.4.0"):
|
if version.parse(datasets.__version__) < version.parse("1.4.0"):
|
||||||
dataset.set_format(
|
dataset.set_format(
|
||||||
@@ -631,6 +635,24 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
return dataset.remove_columns(ignored_columns)
|
return dataset.remove_columns(ignored_columns)
|
||||||
|
|
||||||
|
def _get_collator_with_removed_columns(
    self, data_collator: Callable, description: Optional[str] = None
) -> Callable:
    """
    Return `data_collator`, wrapped so that unused columns are dropped from its output.

    Args:
        data_collator (`Callable`):
            The collator to wrap.
        description (`str`, *optional*):
            Name of the dataset split, forwarded to the wrapper for its log message.

    Returns:
        `Callable`: `data_collator` itself when `self.args.remove_unused_columns` is falsy, otherwise a
        `RemoveColumnsCollator` built around it.
    """
    if self.args.remove_unused_columns:
        # Make sure the model's forward argument names have been collected.
        self._set_signature_columns_if_needed()
        # The configured label names are kept alongside the forward arguments.
        return RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=self._signature_columns + self.label_names,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )

    return data_collator
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.train_dataset is None or not has_length(self.train_dataset):
|
if self.train_dataset is None or not has_length(self.train_dataset):
|
||||||
return None
|
return None
|
||||||
@@ -717,8 +739,11 @@ class Trainer:
|
|||||||
raise ValueError("Trainer: training requires a train_dataset.")
|
raise ValueError("Trainer: training requires a train_dataset.")
|
||||||
|
|
||||||
train_dataset = self.train_dataset
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator
|
||||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||||
|
|
||||||
if isinstance(train_dataset, torch.utils.data.IterableDataset):
|
if isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||||
if self.args.world_size > 1:
|
if self.args.world_size > 1:
|
||||||
@@ -733,7 +758,7 @@ class Trainer:
|
|||||||
return DataLoader(
|
return DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=self.args.per_device_train_batch_size,
|
batch_size=self.args.per_device_train_batch_size,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=data_collator,
|
||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
pin_memory=self.args.dataloader_pin_memory,
|
||||||
)
|
)
|
||||||
@@ -744,7 +769,7 @@ class Trainer:
|
|||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=self._train_batch_size,
|
batch_size=self._train_batch_size,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
pin_memory=self.args.dataloader_pin_memory,
|
||||||
@@ -794,9 +819,12 @@ class Trainer:
|
|||||||
if eval_dataset is None and self.eval_dataset is None:
|
if eval_dataset is None and self.eval_dataset is None:
|
||||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
data_collator = self.data_collator
|
||||||
|
|
||||||
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||||
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
|
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
|
||||||
|
|
||||||
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
||||||
if self.args.world_size > 1:
|
if self.args.world_size > 1:
|
||||||
@@ -810,7 +838,7 @@ class Trainer:
|
|||||||
return DataLoader(
|
return DataLoader(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=data_collator,
|
||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
pin_memory=self.args.dataloader_pin_memory,
|
||||||
)
|
)
|
||||||
@@ -821,7 +849,7 @@ class Trainer:
|
|||||||
eval_dataset,
|
eval_dataset,
|
||||||
sampler=eval_sampler,
|
sampler=eval_sampler,
|
||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
pin_memory=self.args.dataloader_pin_memory,
|
||||||
@@ -838,8 +866,12 @@ class Trainer:
|
|||||||
The test dataset to use. If it is a `datasets.Dataset`, columns not accepted by the `model.forward()`
|
The test dataset to use. If it is a `datasets.Dataset`, columns not accepted by the `model.forward()`
|
||||||
method are automatically removed. It must implement `__len__`.
|
method are automatically removed. It must implement `__len__`.
|
||||||
"""
|
"""
|
||||||
|
data_collator = self.data_collator
|
||||||
|
|
||||||
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
||||||
test_dataset = self._remove_unused_columns(test_dataset, description="test")
|
test_dataset = self._remove_unused_columns(test_dataset, description="test")
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
|
||||||
|
|
||||||
if isinstance(test_dataset, torch.utils.data.IterableDataset):
|
if isinstance(test_dataset, torch.utils.data.IterableDataset):
|
||||||
if self.args.world_size > 1:
|
if self.args.world_size > 1:
|
||||||
@@ -853,7 +885,7 @@ class Trainer:
|
|||||||
return DataLoader(
|
return DataLoader(
|
||||||
test_dataset,
|
test_dataset,
|
||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=data_collator,
|
||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
pin_memory=self.args.dataloader_pin_memory,
|
||||||
)
|
)
|
||||||
@@ -865,7 +897,7 @@ class Trainer:
|
|||||||
test_dataset,
|
test_dataset,
|
||||||
sampler=test_sampler,
|
sampler=test_sampler,
|
||||||
batch_size=self.args.eval_batch_size,
|
batch_size=self.args.eval_batch_size,
|
||||||
collate_fn=self.data_collator,
|
collate_fn=data_collator,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
drop_last=self.args.dataloader_drop_last,
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
pin_memory=self.args.dataloader_pin_memory,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -655,3 +655,39 @@ class FSDPOption(ExplicitEnum):
|
|||||||
SHARD_GRAD_OP = "shard_grad_op"
|
SHARD_GRAD_OP = "shard_grad_op"
|
||||||
OFFLOAD = "offload"
|
OFFLOAD = "offload"
|
||||||
AUTO_WRAP = "auto_wrap"
|
AUTO_WRAP = "auto_wrap"
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveColumnsCollator:
    """
    Wrap the data collator to remove unused columns from its output.

    Args:
        data_collator (`Callable`):
            Collator turning a list of features into a batch. The batch must expose `keys()` and `items()`.
        signature_columns (`List[str]`):
            Names of the columns to keep in the collated batch.
        logger (*optional*):
            Object with an `info(msg)` method used to report the dropped columns. Nothing is logged without it.
        model_name (`str`, *optional*):
            Model class name quoted in the log message. Nothing is logged without it.
        description (`str`, *optional*):
            Name of the dataset split (e.g. `"training"`) inserted in the log message.
    """

    def __init__(
        self,
        data_collator,
        signature_columns,
        logger=None,
        model_name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        self.data_collator = data_collator
        self.signature_columns = signature_columns
        self.logger = logger
        self.description = description
        self.model_name = model_name
        # Flipped to True after the first message so dropped columns are reported only once.
        self.message_logged = False

    def _remove_columns(self, feature: dict) -> dict:
        """Return a new dict holding only the entries of `feature` whose key is a signature column."""
        # Built once per call so each key test below is a set lookup rather than a list scan.
        kept_columns = set(self.signature_columns)

        if not self.message_logged and self.logger and self.model_name:
            # Sorted so the message is reproducible: iteration order of a set of strings varies between runs.
            ignored_columns = sorted(set(feature.keys()) - kept_columns)
            if ignored_columns:
                dset_description = "" if self.description is None else f"in the {self.description} set"
                self.logger.info(
                    f"The following columns {dset_description} don't have a corresponding argument in "
                    f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
                    f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
                    " you can safely ignore this message."
                )
                self.message_logged = True

        return {k: v for k, v in feature.items() if k in kept_columns}

    def __call__(self, features: List[dict]):
        """Collate `features` with the wrapped collator, then drop the unused columns from the batch."""
        return self._remove_columns(self.data_collator(features))
|
||||||
|
|||||||
@@ -289,8 +289,7 @@ class TrainingArguments:
|
|||||||
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
|
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
|
||||||
set to warn or lower (default), `False` otherwise.
|
set to warn or lower (default), `False` otherwise.
|
||||||
remove_unused_columns (`bool`, *optional*, defaults to `True`):
|
remove_unused_columns (`bool`, *optional*, defaults to `True`):
|
||||||
If using `datasets.Dataset` datasets, whether or not to automatically remove the columns unused by the
|
Whether or not to automatically remove the columns unused by the model forward method.
|
||||||
model forward method.
|
|
||||||
|
|
||||||
(Note that this behavior is not implemented for [`TFTrainer`] yet.)
|
(Note that this behavior is not implemented for [`TFTrainer`] yet.)
|
||||||
label_names (`List[str]`, *optional*):
|
label_names (`List[str]`, *optional*):
|
||||||
|
|||||||
@@ -1329,7 +1329,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
def test_training_iterable_dataset(self):
|
def test_training_iterable_dataset(self):
|
||||||
config = RegressionModelConfig()
|
config = RegressionModelConfig()
|
||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
train_dataset = SampleIterableDataset()
|
# Adding one column not used by the model should have no impact
|
||||||
|
train_dataset = SampleIterableDataset(label_names=["labels", "extra"])
|
||||||
|
|
||||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
|
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
|
||||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||||
@@ -1363,7 +1364,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
def test_evaluation_iterable_dataset(self):
|
def test_evaluation_iterable_dataset(self):
|
||||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
eval_dataset = SampleIterableDataset()
|
# Adding one column not used by the model should have no impact
|
||||||
|
eval_dataset = SampleIterableDataset(label_names=["labels", "extra"])
|
||||||
|
|
||||||
args = RegressionTrainingArguments(output_dir="./examples")
|
args = RegressionTrainingArguments(output_dir="./examples")
|
||||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||||
@@ -1400,7 +1402,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||||
|
|
||||||
# With a number of elements not a round multiple of the batch size
|
# With a number of elements not a round multiple of the batch size
|
||||||
test_dataset = SampleIterableDataset(length=66)
|
# Adding one column not used by the model should have no impact
|
||||||
|
test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"])
|
||||||
preds = trainer.predict(test_dataset).predictions
|
preds = trainer.predict(test_dataset).predictions
|
||||||
x = test_dataset.dataset.x
|
x = test_dataset.dataset.x
|
||||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.data.data_collator import default_data_collator
|
||||||
from transformers.testing_utils import require_accelerate, require_torch
|
from transformers.testing_utils import require_accelerate, require_torch
|
||||||
from transformers.trainer_utils import find_executable_batch_size
|
from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
|
||||||
from transformers.utils import is_torch_available
|
from transformers.utils import is_torch_available
|
||||||
|
|
||||||
|
|
||||||
@@ -457,3 +458,28 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
with self.assertRaises(RuntimeError) as cm:
|
with self.assertRaises(RuntimeError) as cm:
|
||||||
mock_training_loop_function()
|
mock_training_loop_function()
|
||||||
self.assertEqual("CUDA out of memory", cm.args[0])
|
self.assertEqual("CUDA out of memory", cm.args[0])
|
||||||
|
|
||||||
|
def test_remove_columns_collator(self):
    """Check that `RemoveColumnsCollator` drops unknown columns and logs about them only once."""

    class MockLogger:
        """Logger stand-in that counts `info` calls and remembers the last message."""

        def __init__(self) -> None:
            self.called = 0

        def info(self, msg):
            self.called += 1
            self.last_msg = msg

    # Two identical rows; "col3" is not among the columns the collator is told to keep.
    row = {"col1": 1, "col2": 2, "col3": 3}
    data_batch = [dict(row), dict(row)]

    logger = MockLogger()
    remove_columns_collator = RemoveColumnsCollator(
        data_collator=default_data_collator,
        signature_columns=["col1", "col2"],
        logger=logger,
        model_name="model",
        description="training",
    )

    collated = remove_columns_collator(data_batch)
    self.assertNotIn("col3", collated)

    # check that the logging message is printed out only once
    for _ in range(2):
        remove_columns_collator(data_batch)
    self.assertEqual(logger.called, 1)
    self.assertIn("col3", logger.last_msg)
|
||||||
|
|||||||
Reference in New Issue
Block a user