Add dataset_name to data_args and add accuracy metric (#11760)
* Add `dataset_name` to data_args and add accuracy metric
* Add documentation for `dataset_name`
* Spelling correction
This commit is contained in:
@@ -76,6 +76,12 @@ class DataTrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
|
||||
)
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
@@ -127,8 +133,10 @@ class DataTrainingArguments:
|
||||
self.task_name = self.task_name.lower()
|
||||
if self.task_name not in task_to_keys.keys():
|
||||
raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
|
||||
elif self.dataset_name is not None:
|
||||
pass
|
||||
elif self.train_file is None or self.validation_file is None:
|
||||
raise ValueError("Need either a GLUE task or a training/validation file.")
|
||||
raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
|
||||
else:
|
||||
train_extension = self.train_file.split(".")[-1]
|
||||
assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
@@ -240,6 +248,9 @@ def main():
|
||||
if data_args.task_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
|
||||
elif data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
||||
else:
|
||||
# Loading a dataset from your local files.
|
||||
# CSV/JSON training and evaluation files are needed.
|
||||
@@ -408,8 +419,8 @@ def main():
|
||||
# Get the metric function
|
||||
if data_args.task_name is not None:
|
||||
metric = load_metric("glue", data_args.task_name)
|
||||
# TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
|
||||
# compute_metrics
|
||||
else:
|
||||
metric = load_metric("accuracy")
|
||||
|
||||
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
|
||||
# predictions and label_ids field) and has to return a dictionary string to float.
|
||||
|
||||
Reference in New Issue
Block a user