fix: Allow only test_file in pytorch and flax summarization (#22293)
allow only test_file in pytorch and flax summarization
This commit is contained in:
@@ -308,8 +308,13 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
if (
|
||||
self.dataset_name is None
|
||||
and self.train_file is None
|
||||
and self.validation_file is None
|
||||
and self.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training, validation, or test file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
@@ -317,6 +322,9 @@ class DataTrainingArguments:
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.test_file is not None:
|
||||
extension = self.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
@@ -553,10 +561,16 @@ def main():
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
if training_args.do_train:
|
||||
if "train" not in dataset:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
column_names = dataset["train"].column_names
|
||||
elif training_args.do_eval:
|
||||
if "validation" not in dataset:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
column_names = dataset["validation"].column_names
|
||||
elif training_args.do_predict:
|
||||
if "test" not in dataset:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
column_names = dataset["test"].column_names
|
||||
else:
|
||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||
@@ -620,8 +634,6 @@ def main():
|
||||
return model_inputs
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in dataset:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = dataset["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
@@ -637,8 +649,6 @@ def main():
|
||||
|
||||
if training_args.do_eval:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "validation" not in dataset:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = dataset["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
@@ -654,8 +664,6 @@ def main():
|
||||
|
||||
if training_args.do_predict:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "test" not in dataset:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
predict_dataset = dataset["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
|
||||
@@ -262,8 +262,13 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
if (
|
||||
self.dataset_name is None
|
||||
and self.train_file is None
|
||||
and self.validation_file is None
|
||||
and self.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training, validation, or test file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
@@ -271,6 +276,9 @@ class DataTrainingArguments:
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.test_file is not None:
|
||||
extension = self.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
@@ -467,10 +475,16 @@ def main():
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
if training_args.do_train:
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
column_names = raw_datasets["train"].column_names
|
||||
elif training_args.do_eval:
|
||||
if "validation" not in raw_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
elif training_args.do_predict:
|
||||
if "test" not in raw_datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
column_names = raw_datasets["test"].column_names
|
||||
else:
|
||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||
@@ -546,8 +560,6 @@ def main():
|
||||
return model_inputs
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = raw_datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
@@ -564,8 +576,6 @@ def main():
|
||||
|
||||
if training_args.do_eval:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "validation" not in raw_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = raw_datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
@@ -582,8 +592,6 @@ def main():
|
||||
|
||||
if training_args.do_predict:
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "test" not in raw_datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
predict_dataset = raw_datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
|
||||
Reference in New Issue
Block a user