Update to use datasets remove_columns method (#11343)

* Update to use datasets remove_columns method

* Quality
This commit is contained in:
Sylvain Gugger
2021-04-20 14:12:01 -04:00
committed by GitHub
parent cfd2eaa8cf
commit f1b938fda8
3 changed files with 21 additions and 27 deletions

View File

@@ -1 +1 @@
datasets >= 1.2.1
datasets >= 1.4.0

View File

@@ -16,13 +16,10 @@
A subclass of `Trainer` specific to Question-Answering tasks
"""
from transformers import Trainer, is_datasets_available, is_torch_tpu_available
from transformers import Trainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput
if is_datasets_available():
import datasets
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
@@ -54,10 +51,6 @@ class QuestionAnsweringTrainer(Trainer):
finally:
self.compute_metrics = compute_metrics
# We might have removed columns from the dataset so we put them back.
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset.set_format(type=eval_dataset.format["type"], columns=list(eval_dataset.features.keys()))
if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
metrics = self.compute_metrics(eval_preds)
@@ -94,10 +87,6 @@ class QuestionAnsweringTrainer(Trainer):
if self.post_process_function is None or self.compute_metrics is None:
return output
# We might have removed columns from the dataset so we put them back.
if isinstance(test_dataset, datasets.Dataset):
test_dataset.set_format(type=test_dataset.format["type"], columns=list(test_dataset.features.keys()))
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
metrics = self.compute_metrics(eval_preds)