From 8e6c34b390efff20a045f1942732cd20928e90e7 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Wed, 22 Mar 2023 06:46:56 -0400 Subject: [PATCH] fix: Allow only test_file in pytorch and flax summarization (#22293) allow only test_file in pytorch and flax summarization --- .../summarization/run_summarization_flax.py | 24 ++++++++++++------- .../summarization/run_summarization.py | 24 ++++++++++++------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 67f164bc0b..2d7e0acbf5 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -308,8 +308,13 @@ class DataTrainingArguments: ) def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training, validation, or test file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -317,6 +322,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." if self.val_max_target_length is None: self.val_max_target_length = self.max_target_length @@ -553,10 +561,16 @@ def main(): # Preprocessing the datasets. # We need to tokenize inputs and targets. if training_args.do_train: + if "train" not in dataset: + raise ValueError("--do_train requires a train dataset") column_names = dataset["train"].column_names elif training_args.do_eval: + if "validation" not in dataset: + raise ValueError("--do_eval requires a validation dataset") column_names = dataset["validation"].column_names elif training_args.do_predict: + if "test" not in dataset: + raise ValueError("--do_predict requires a test dataset") column_names = dataset["test"].column_names else: logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") @@ -620,8 +634,6 @@ def main(): return model_inputs if training_args.do_train: - if "train" not in dataset: - raise ValueError("--do_train requires a train dataset") train_dataset = dataset["train"] if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) @@ -637,8 +649,6 @@ def main(): if training_args.do_eval: max_target_length = data_args.val_max_target_length - if "validation" not in dataset: - raise ValueError("--do_eval requires a validation dataset") eval_dataset = dataset["validation"] if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) @@ -654,8 +664,6 @@ def main(): if training_args.do_predict: max_target_length = data_args.val_max_target_length - if "test" not in dataset: - raise ValueError("--do_predict requires a test dataset") predict_dataset = dataset["test"] if data_args.max_predict_samples is not None: max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index e5566f9363..587ff5b770 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -262,8 +262,13 @@ class DataTrainingArguments: ) def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training, validation, or test file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -271,6 +276,9 @@ class DataTrainingArguments: if self.validation_file is not None: extension = self.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." if self.val_max_target_length is None: self.val_max_target_length = self.max_target_length @@ -467,10 +475,16 @@ def main(): # Preprocessing the datasets. # We need to tokenize inputs and targets. if training_args.do_train: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") column_names = raw_datasets["train"].column_names elif training_args.do_eval: + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") column_names = raw_datasets["validation"].column_names elif training_args.do_predict: + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") column_names = raw_datasets["test"].column_names else: logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") @@ -546,8 +560,6 @@ def main(): return model_inputs if training_args.do_train: - if "train" not in raw_datasets: - raise ValueError("--do_train requires a train dataset") train_dataset = raw_datasets["train"] if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) @@ -564,8 +576,6 @@ def main(): if training_args.do_eval: max_target_length = data_args.val_max_target_length - if "validation" not in raw_datasets: - raise ValueError("--do_eval requires a validation dataset") eval_dataset = raw_datasets["validation"] if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) @@ -582,8 +592,6 @@ def main(): if training_args.do_predict: max_target_length = data_args.val_max_target_length - if "test" not in raw_datasets: - raise ValueError("--do_predict requires a test dataset") predict_dataset = raw_datasets["test"] if data_args.max_predict_samples is not None: max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)