Add dataset_name to data_args and add accuracy metric (#11760)
* Add `dataset_name` to data_args and add accuracy metric
* Add documentation for `dataset_name`
* Spelling correction
This commit is contained in:
@@ -22,8 +22,8 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
|
|||||||
|
|
||||||
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
|
||||||
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
|
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
|
||||||
and can also be used for your own data in a csv or a JSON file (the script might need some tweaks in that case, refer
|
and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a CSV or a JSON file
|
||||||
to the comments inside for help).
|
(the script might need some tweaks in that case, refer to the comments inside for help).
|
||||||
|
|
||||||
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
|
||||||
|
|
||||||
@@ -64,6 +64,22 @@ single Titan RTX was used):
|
|||||||
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
|
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
|
||||||
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
|
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
|
||||||
|
|
||||||
|
The following example fine-tunes BERT on the `imdb` dataset hosted on our [hub](https://huggingface.co/datasets):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run_glue.py \
|
||||||
|
--model_name_or_path bert-base-cased \
|
||||||
|
--dataset_name imdb \
|
||||||
|
--do_train \
|
||||||
|
--do_predict \
|
||||||
|
--max_seq_length 128 \
|
||||||
|
--per_device_train_batch_size 32 \
|
||||||
|
--learning_rate 2e-5 \
|
||||||
|
--num_train_epochs 3 \
|
||||||
|
--output_dir /tmp/imdb/
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Mixed precision training
|
### Mixed precision training
|
||||||
|
|
||||||
If you have a GPU with mixed precision capabilities (architecture Pascal or more recent), you can use mixed precision
|
If you have a GPU with mixed precision capabilities (architecture Pascal or more recent), you can use mixed precision
|
||||||
|
|||||||
@@ -76,6 +76,12 @@ class DataTrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
|
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
|
||||||
)
|
)
|
||||||
|
dataset_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||||
|
)
|
||||||
|
dataset_config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||||
|
)
|
||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -127,8 +133,10 @@ class DataTrainingArguments:
|
|||||||
self.task_name = self.task_name.lower()
|
self.task_name = self.task_name.lower()
|
||||||
if self.task_name not in task_to_keys.keys():
|
if self.task_name not in task_to_keys.keys():
|
||||||
raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
|
raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
|
||||||
|
elif self.dataset_name is not None:
|
||||||
|
pass
|
||||||
elif self.train_file is None or self.validation_file is None:
|
elif self.train_file is None or self.validation_file is None:
|
||||||
raise ValueError("Need either a GLUE task or a training/validation file.")
|
raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
|
||||||
else:
|
else:
|
||||||
train_extension = self.train_file.split(".")[-1]
|
train_extension = self.train_file.split(".")[-1]
|
||||||
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||||
@@ -240,6 +248,9 @@ def main():
|
|||||||
if data_args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
|
datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
|
||||||
|
elif data_args.dataset_name is not None:
|
||||||
|
# Downloading and loading a dataset from the hub.
|
||||||
|
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
||||||
else:
|
else:
|
||||||
# Loading a dataset from your local files.
|
# Loading a dataset from your local files.
|
||||||
# CSV/JSON training and evaluation files are needed.
|
# CSV/JSON training and evaluation files are needed.
|
||||||
@@ -408,8 +419,8 @@ def main():
|
|||||||
# Get the metric function
|
# Get the metric function
|
||||||
if data_args.task_name is not None:
|
if data_args.task_name is not None:
|
||||||
metric = load_metric("glue", data_args.task_name)
|
metric = load_metric("glue", data_args.task_name)
|
||||||
# TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
|
else:
|
||||||
# compute_metrics
|
metric = load_metric("accuracy")
|
||||||
|
|
||||||
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
||||||
# predictions and label_ids field) and has to return a dictionary string to float.
|
# predictions and label_ids field) and has to return a dictionary string to float.
|
||||||
|
|||||||
Reference in New Issue
Block a user