From f1b938fda81d4b9e8ab435cb7f37f71c9b7cbb1e Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 20 Apr 2021 14:12:01 -0400 Subject: [PATCH] Update to use datasets remove_columns method (#11343) * Update to use datasets remove_columns method * Quality --- examples/question-answering/requirements.txt | 2 +- examples/question-answering/trainer_qa.py | 13 +------- src/transformers/trainer.py | 33 +++++++++++--------- 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/examples/question-answering/requirements.txt b/examples/question-answering/requirements.txt index c8205f0d3d..5a9f0358d3 100644 --- a/examples/question-answering/requirements.txt +++ b/examples/question-answering/requirements.txt @@ -1 +1 @@ -datasets >= 1.2.1 +datasets >= 1.4.0 diff --git a/examples/question-answering/trainer_qa.py b/examples/question-answering/trainer_qa.py index db7b80c015..41699cd1df 100644 --- a/examples/question-answering/trainer_qa.py +++ b/examples/question-answering/trainer_qa.py @@ -16,13 +16,10 @@ A subclass of `Trainer` specific to Question-Answering tasks """ -from transformers import Trainer, is_datasets_available, is_torch_tpu_available +from transformers import Trainer, is_torch_tpu_available from transformers.trainer_utils import PredictionOutput -if is_datasets_available(): - import datasets - if is_torch_tpu_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met @@ -54,10 +51,6 @@ class QuestionAnsweringTrainer(Trainer): finally: self.compute_metrics = compute_metrics - # We might have removed columns from the dataset so we put them back. 
- if isinstance(eval_dataset, datasets.Dataset): - eval_dataset.set_format(type=eval_dataset.format["type"], columns=list(eval_dataset.features.keys())) - if self.post_process_function is not None and self.compute_metrics is not None: eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) metrics = self.compute_metrics(eval_preds) @@ -94,10 +87,6 @@ class QuestionAnsweringTrainer(Trainer): if self.post_process_function is None or self.compute_metrics is None: return output - # We might have removed columns from the dataset so we put them back. - if isinstance(test_dataset, datasets.Dataset): - test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys())) - eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test") metrics = self.compute_metrics(eval_preds) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a0d4440f2f..254f7d8e6e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -394,11 +394,6 @@ class Trainer: raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") self._signature_columns = None - if is_datasets_available(): - if isinstance(train_dataset, datasets.Dataset): - self._remove_unused_columns(self.train_dataset, description="training") - if isinstance(eval_dataset, datasets.Dataset): - self._remove_unused_columns(self.eval_dataset, description="evaluation") # Mixed precision setup self.use_apex = False @@ -503,7 +498,13 @@ class Trainer: f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." 
) - dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]) + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if not isinstance(self.train_dataset, collections.abc.Sized): @@ -565,17 +566,20 @@ class Trainer: if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") - if isinstance(self.train_dataset, torch.utils.data.dataset.IterableDataset): + train_dataset = self.train_dataset + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + + if isinstance(train_dataset, torch.utils.data.dataset.IterableDataset): if self.args.world_size > 1: train_dataset = IterableDatasetShard( - self.train_dataset, + train_dataset, batch_size=self.args.train_batch_size, drop_last=self.args.dataloader_drop_last, num_processes=self.args.world_size, process_index=self.args.process_index, ) - else: - train_dataset = self.train_dataset + return DataLoader( train_dataset, batch_size=self.args.train_batch_size, @@ -587,7 +591,7 @@ class Trainer: train_sampler = self._get_train_sampler() return DataLoader( - self.train_dataset, + train_dataset, batch_size=self.args.train_batch_size, sampler=train_sampler, collate_fn=self.data_collator, @@ -638,10 +642,11 @@ class Trainer: """ if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") - elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): - self._remove_unused_columns(eval_dataset, description="evaluation") eval_dataset = eval_dataset if eval_dataset is not None else 
self.eval_dataset + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset): if self.args.world_size > 1: eval_dataset = IterableDatasetShard( @@ -683,7 +688,7 @@ class Trainer: ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`. """ if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): - self._remove_unused_columns(test_dataset, description="test") + test_dataset = self._remove_unused_columns(test_dataset, description="test") if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset): if self.args.world_size > 1: