From 04e25c62863c9b5e7523ad95778a8b5fa709244d Mon Sep 17 00:00:00 2001 From: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Date: Tue, 18 May 2021 16:27:29 +0200 Subject: [PATCH] add `dataset_name` to data_args and added accuracy metric (#11760) * add `dataset_name` to data_args and added accuracy metric * added documentation for dataset_name * spelling correction --- .../pytorch/text-classification/README.md | 20 +++++++++++++++++-- .../pytorch/text-classification/run_glue.py | 17 +++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/text-classification/README.md b/examples/pytorch/text-classification/README.md index fac7b0eb4b..e3fca9e399 100644 --- a/examples/pytorch/text-classification/README.md +++ b/examples/pytorch/text-classification/README.md @@ -22,8 +22,8 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/ Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) -and can also be used for your own data in a csv or a JSON file (the script might need some tweaks in that case, refer -to the comments inside for help). +and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file +(the script might need some tweaks in that case, refer to the comments inside for help). GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them: @@ -64,6 +64,22 @@ single Titan RTX was used): Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website. +The following example fine-tunes BERT on the `imdb` dataset hosted on our [hub](https://huggingface.co/datasets): + +```bash +python run_glue.py \ + --model_name_or_path bert-base-cased \ + --dataset_name imdb \ + --do_train \ + --do_predict \ + --max_seq_length 128 \ + --per_device_train_batch_size 32 \ + --learning_rate 2e-5 \ + --num_train_epochs 3 \ + --output_dir /tmp/imdb/ +``` + + ### Mixed precision training If you have a GPU with mixed precision capabilities (architecture Pascal or more recent), you can use mixed precision diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 453a488eaf..5953aa6cdc 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -76,6 +76,12 @@ class DataTrainingArguments: default=None, metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, ) + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) max_seq_length: int = field( default=128, metadata={ @@ -127,8 +133,10 @@ class DataTrainingArguments: self.task_name = self.task_name.lower() if self.task_name not in task_to_keys.keys(): raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) + elif self.dataset_name is not None: + pass elif self.train_file is None or self.validation_file is None: - raise ValueError("Need either a GLUE task or a training/validation file.") + raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") else: train_extension = self.train_file.split(".")[-1] assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." @@ -240,6 +248,9 @@ def main(): if data_args.task_name is not None: # Downloading and loading a dataset from the hub. datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) + elif data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) else: # Loading a dataset from your local files. # CSV/JSON training and evaluation files are needed. @@ -408,8 +419,8 @@ def main(): # Get the metric function if data_args.task_name is not None: metric = load_metric("glue", data_args.task_name) - # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from - # compute_metrics + else: + metric = load_metric("accuracy") # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a # predictions and label_ids field) and has to return a dictionary string to float.