From eabad8fd9c8e24e359a022e55e2a46bdd8f50b6f Mon Sep 17 00:00:00 2001 From: Yusuke Mori Date: Wed, 13 Jan 2021 21:48:35 +0900 Subject: [PATCH] Update run_glue for do_predict with local test data (#9442) (#9486) * Update run_glue for do_predict with local test data (#9442) * Update run_glue (#9442): fix comments ('files' to 'a file') * Update run_glue (#9442): reflect the code review * Update run_glue (#9442): auto format * Update run_glue (#9442): reflect the code review --- examples/text-classification/run_glue.py | 48 +++++++++++++++++------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 2a6f0942a6..60d33786b4 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -93,6 +93,7 @@ class DataTrainingArguments: validation_file: Optional[str] = field( default=None, metadata={"help": "A csv or a json file containing the validation data."} ) + test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) def __post_init__(self): if self.task_name is not None: @@ -102,10 +103,12 @@ class DataTrainingArguments: elif self.train_file is None or self.validation_file is None: raise ValueError("Need either a GLUE task or a training/validation file.") else: - extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." - extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + train_extension = self.train_file.split(".")[-1] + assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." @dataclass @@ -205,16 +208,33 @@ def main(): if data_args.task_name is not None: # Downloading and loading a dataset from the hub. datasets = load_dataset("glue", data_args.task_name) - elif data_args.train_file.endswith(".csv"): - # Loading a dataset from local csv files - datasets = load_dataset( - "csv", data_files={"train": data_args.train_file, "validation": data_args.validation_file} - ) else: - # Loading a dataset from local json files - datasets = load_dataset( - "json", data_files={"train": data_args.train_file, "validation": data_args.validation_file} - ) + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. + data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset("csv", data_files=data_files) + else: + # Loading a dataset from local json files + datasets = load_dataset("json", data_files=data_files) # See more about loading any type of standard or custom dataset at # https://huggingface.co/docs/datasets/loading_datasets.html. @@ -323,7 +343,7 @@ def main(): train_dataset = datasets["train"] eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] - if data_args.task_name is not None: + if data_args.task_name is not None or data_args.test_file is not None: test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] # Log a few random samples from the training set: