Replace assertion with ValueError exceptions in run_image_captioning_flax.py (#20365)
* replace 4 asserts with ValueError exception for control flow * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * reformatted file * uninstalled trasformers and applied make style Co-authored-by: Bibi <Bibi@katies-mac.local> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -298,10 +298,12 @@ class DataTrainingArguments:
|
|||||||
else:
|
else:
|
||||||
if self.train_file is not None:
|
if self.train_file is not None:
|
||||||
extension = self.train_file.split(".")[-1]
|
extension = self.train_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
if extension not in ["csv", "json"]:
|
||||||
|
raise ValueError(f"`train_file` should be a csv or a json file, got {extension}.")
|
||||||
if self.validation_file is not None:
|
if self.validation_file is not None:
|
||||||
extension = self.validation_file.split(".")[-1]
|
extension = self.validation_file.split(".")[-1]
|
||||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
if extension not in ["csv", "json"]:
|
||||||
|
raise ValueError(f"`validation_file` should be a csv or a json file, got {extension}.")
|
||||||
if self.val_max_target_length is None:
|
if self.val_max_target_length is None:
|
||||||
self.val_max_target_length = self.max_target_length
|
self.val_max_target_length = self.max_target_length
|
||||||
|
|
||||||
@@ -502,7 +504,12 @@ def main():
|
|||||||
# Get the column names for input/target.
|
# Get the column names for input/target.
|
||||||
dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
|
dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
|
||||||
if data_args.image_column is None:
|
if data_args.image_column is None:
|
||||||
assert dataset_columns is not None
|
if dataset_columns is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure"
|
||||||
|
" to set `--dataset_name` to the correct dataset name, one of"
|
||||||
|
f" {', '.join(image_captioning_name_mapping.keys())}."
|
||||||
|
)
|
||||||
image_column = dataset_columns[0]
|
image_column = dataset_columns[0]
|
||||||
else:
|
else:
|
||||||
image_column = data_args.image_column
|
image_column = data_args.image_column
|
||||||
@@ -511,7 +518,12 @@ def main():
|
|||||||
f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
|
f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||||
)
|
)
|
||||||
if data_args.caption_column is None:
|
if data_args.caption_column is None:
|
||||||
assert dataset_columns is not None
|
if dataset_columns is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure"
|
||||||
|
" to set `--dataset_name` to the correct dataset name, one of"
|
||||||
|
f" {', '.join(image_captioning_name_mapping.keys())}."
|
||||||
|
)
|
||||||
caption_column = dataset_columns[1]
|
caption_column = dataset_columns[1]
|
||||||
else:
|
else:
|
||||||
caption_column = data_args.caption_column
|
caption_column = data_args.caption_column
|
||||||
|
|||||||
Reference in New Issue
Block a user