Replace assertion with ValueError exceptions in run_image_captioning_flax.py (#20365)
* replace 4 asserts with ValueError exception for control flow * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * reformatted file * uninstalled trasformers and applied make style Co-authored-by: Bibi <Bibi@katies-mac.local> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -298,10 +298,12 @@ class DataTrainingArguments:
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
if extension not in ["csv", "json"]:
|
||||
raise ValueError(f"`train_file` should be a csv or a json file, got {extension}.")
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if extension not in ["csv", "json"]:
|
||||
raise ValueError(f"`validation_file` should be a csv or a json file, got {extension}.")
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
@@ -502,7 +504,12 @@ def main():
|
||||
# Get the column names for input/target.
|
||||
dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
|
||||
if data_args.image_column is None:
|
||||
assert dataset_columns is not None
|
||||
if dataset_columns is None:
|
||||
raise ValueError(
|
||||
f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure"
|
||||
" to set `--dataset_name` to the correct dataset name, one of"
|
||||
f" {', '.join(image_captioning_name_mapping.keys())}."
|
||||
)
|
||||
image_column = dataset_columns[0]
|
||||
else:
|
||||
image_column = data_args.image_column
|
||||
@@ -511,7 +518,12 @@ def main():
|
||||
f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if data_args.caption_column is None:
|
||||
assert dataset_columns is not None
|
||||
if dataset_columns is None:
|
||||
raise ValueError(
|
||||
f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure"
|
||||
" to set `--dataset_name` to the correct dataset name, one of"
|
||||
f" {', '.join(image_captioning_name_mapping.keys())}."
|
||||
)
|
||||
caption_column = dataset_columns[1]
|
||||
else:
|
||||
caption_column = data_args.caption_column
|
||||
|
||||
Reference in New Issue
Block a user