From ebd48c6de544e22dd4c5743fb27039bf24b811e1 Mon Sep 17 00:00:00 2001 From: Emanuel Huber Date: Tue, 26 Oct 2021 18:14:29 -0300 Subject: [PATCH] Replace assertions with ValueError exception (#14142) Updated masked-language modeling examples in pytorch with convention defined by #12789 --- examples/pytorch/language-modeling/run_mlm.py | 6 ++++-- examples/pytorch/language-modeling/run_mlm_no_trainer.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 49c086d71e..dec749d850 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -188,10 +188,12 @@ class DataTrainingArguments: else: if self.train_file is not None: extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, a json or a txt file.") if self.validation_file is not None: extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, a json or a txt file.") def main(): diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index 3e2241495b..cf2841ab5f 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -201,10 +201,12 @@ def parse_args(): else: if args.train_file is not None: extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, json or txt file.") if args.validation_file is not None: extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, json or txt file.") if args.push_to_hub: assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."