fix: Allow only test_file in pytorch and flax summarization (#22293)
allow only test_file in pytorch and flax summarization
This commit is contained in:
@@ -308,8 +308,13 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if (
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
self.dataset_name is None
|
||||||
|
and self.train_file is None
|
||||||
|
and self.validation_file is None
|
||||||
|
and self.test_file is None
|
||||||
|
):
|
||||||
|
raise ValueError("Need either a dataset name or a training, validation, or test file.")
|
||||||
else:
|
else:
|
||||||
if self.train_file is not None:
|
if self.train_file is not None:
|
||||||
extension = self.train_file.split(".")[-1]
|
extension = self.train_file.split(".")[-1]
|
||||||
@@ -317,6 +322,9 @@ class DataTrainingArguments:
|
|||||||
if self.validation_file is not None:
|
if self.validation_file is not None:
|
||||||
extension = self.validation_file.split(".")[-1]
|
extension = self.validation_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||||
|
if self.test_file is not None:
|
||||||
|
extension = self.test_file.split(".")[-1]
|
||||||
|
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||||
if self.val_max_target_length is None:
|
if self.val_max_target_length is None:
|
||||||
self.val_max_target_length = self.max_target_length
|
self.val_max_target_length = self.max_target_length
|
||||||
|
|
||||||
@@ -553,10 +561,16 @@ def main():
|
|||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# We need to tokenize inputs and targets.
|
# We need to tokenize inputs and targets.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
|
if "train" not in dataset:
|
||||||
|
raise ValueError("--do_train requires a train dataset")
|
||||||
column_names = dataset["train"].column_names
|
column_names = dataset["train"].column_names
|
||||||
elif training_args.do_eval:
|
elif training_args.do_eval:
|
||||||
|
if "validation" not in dataset:
|
||||||
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
column_names = dataset["validation"].column_names
|
column_names = dataset["validation"].column_names
|
||||||
elif training_args.do_predict:
|
elif training_args.do_predict:
|
||||||
|
if "test" not in dataset:
|
||||||
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
column_names = dataset["test"].column_names
|
column_names = dataset["test"].column_names
|
||||||
else:
|
else:
|
||||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||||
@@ -620,8 +634,6 @@ def main():
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in dataset:
|
|
||||||
raise ValueError("--do_train requires a train dataset")
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||||
@@ -637,8 +649,6 @@ def main():
|
|||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "validation" not in dataset:
|
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
|
||||||
eval_dataset = dataset["validation"]
|
eval_dataset = dataset["validation"]
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||||
@@ -654,8 +664,6 @@ def main():
|
|||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "test" not in dataset:
|
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
|
||||||
predict_dataset = dataset["test"]
|
predict_dataset = dataset["test"]
|
||||||
if data_args.max_predict_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||||
|
|||||||
@@ -262,8 +262,13 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if (
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
self.dataset_name is None
|
||||||
|
and self.train_file is None
|
||||||
|
and self.validation_file is None
|
||||||
|
and self.test_file is None
|
||||||
|
):
|
||||||
|
raise ValueError("Need either a dataset name or a training, validation, or test file.")
|
||||||
else:
|
else:
|
||||||
if self.train_file is not None:
|
if self.train_file is not None:
|
||||||
extension = self.train_file.split(".")[-1]
|
extension = self.train_file.split(".")[-1]
|
||||||
@@ -271,6 +276,9 @@ class DataTrainingArguments:
|
|||||||
if self.validation_file is not None:
|
if self.validation_file is not None:
|
||||||
extension = self.validation_file.split(".")[-1]
|
extension = self.validation_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||||
|
if self.test_file is not None:
|
||||||
|
extension = self.test_file.split(".")[-1]
|
||||||
|
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||||
if self.val_max_target_length is None:
|
if self.val_max_target_length is None:
|
||||||
self.val_max_target_length = self.max_target_length
|
self.val_max_target_length = self.max_target_length
|
||||||
|
|
||||||
@@ -467,10 +475,16 @@ def main():
|
|||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# We need to tokenize inputs and targets.
|
# We need to tokenize inputs and targets.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
|
if "train" not in raw_datasets:
|
||||||
|
raise ValueError("--do_train requires a train dataset")
|
||||||
column_names = raw_datasets["train"].column_names
|
column_names = raw_datasets["train"].column_names
|
||||||
elif training_args.do_eval:
|
elif training_args.do_eval:
|
||||||
|
if "validation" not in raw_datasets:
|
||||||
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
column_names = raw_datasets["validation"].column_names
|
column_names = raw_datasets["validation"].column_names
|
||||||
elif training_args.do_predict:
|
elif training_args.do_predict:
|
||||||
|
if "test" not in raw_datasets:
|
||||||
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
column_names = raw_datasets["test"].column_names
|
column_names = raw_datasets["test"].column_names
|
||||||
else:
|
else:
|
||||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||||
@@ -546,8 +560,6 @@ def main():
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in raw_datasets:
|
|
||||||
raise ValueError("--do_train requires a train dataset")
|
|
||||||
train_dataset = raw_datasets["train"]
|
train_dataset = raw_datasets["train"]
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||||
@@ -564,8 +576,6 @@ def main():
|
|||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "validation" not in raw_datasets:
|
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
|
||||||
eval_dataset = raw_datasets["validation"]
|
eval_dataset = raw_datasets["validation"]
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||||
@@ -582,8 +592,6 @@ def main():
|
|||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "test" not in raw_datasets:
|
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
|
||||||
predict_dataset = raw_datasets["test"]
|
predict_dataset = raw_datasets["test"]
|
||||||
if data_args.max_predict_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||||
|
|||||||
Reference in New Issue
Block a user