diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index 493caa6c7b..1258eba49f 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -298,10 +298,12 @@ class DataTrainingArguments: else: if self.train_file is not None: extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if extension not in ["csv", "json"]: + raise ValueError(f"`train_file` should be a csv or a json file, got {extension}.") if self.validation_file is not None: extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if extension not in ["csv", "json"]: + raise ValueError(f"`validation_file` should be a csv or a json file, got {extension}.") if self.val_max_target_length is None: self.val_max_target_length = self.max_target_length @@ -502,7 +504,12 @@ def main(): # Get the column names for input/target. dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None) if data_args.image_column is None: - assert dataset_columns is not None + if dataset_columns is None: + raise ValueError( + f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure" + " to set `--dataset_name` to the correct dataset name, one of" + f" {', '.join(image_captioning_name_mapping.keys())}." + ) image_column = dataset_columns[0] else: image_column = data_args.image_column @@ -511,7 +518,12 @@ def main(): f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}" ) if data_args.caption_column is None: - assert dataset_columns is not None + if dataset_columns is None: + raise ValueError( + f"`--dataset_name` {data_args.dataset_name} not found in dataset '{data_args.dataset_name}'. Make sure" + " to set `--dataset_name` to the correct dataset name, one of" + f" {', '.join(image_captioning_name_mapping.keys())}." + ) caption_column = dataset_columns[1] else: caption_column = data_args.caption_column