From 667ccea72235504ab7876024e4f8c113ca62190f Mon Sep 17 00:00:00 2001 From: Katie Le <54815905+katiele47@users.noreply.github.com> Date: Mon, 28 Nov 2022 10:06:25 -0500 Subject: [PATCH] Replace assertion with ValueError exceptions in run_image_captioning_flax.py (#20365) * replace 4 asserts with ValueError exception for control flow * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * reformatted file * uninstalled trasformers and applied make style Co-authored-by: Bibi Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../run_image_captioning_flax.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index 493caa6c7b..1258eba49f 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -298,10 +298,12 @@ class DataTrainingArguments: else: if self.train_file is not None: extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if extension not in ["csv", "json"]: + raise ValueError(f"`train_file` should be a csv or a json file, got {extension}.") if self.validation_file is not None: extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if extension not in ["csv", "json"]: + raise ValueError(f"`validation_file` should be a csv or a json file, got {extension}.") if self.val_max_target_length is None: self.val_max_target_length = self.max_target_length @@ -502,7 +504,12 @@ def main(): # Get the column names for input/target. dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None) if data_args.image_column is None: - assert dataset_columns is not None + if dataset_columns is None: + raise ValueError( + f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure" + " to set `--dataset_name` to the correct dataset name, one of" + f" {', '.join(image_captioning_name_mapping.keys())}." + ) image_column = dataset_columns[0] else: image_column = data_args.image_column @@ -511,7 +518,12 @@ def main(): f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}" ) if data_args.caption_column is None: - assert dataset_columns is not None + if dataset_columns is None: + raise ValueError( + f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure" + " to set `--dataset_name` to the correct dataset name, one of" + f" {', '.join(image_captioning_name_mapping.keys())}." + ) caption_column = dataset_columns[1] else: caption_column = data_args.caption_column