From ba47efbfe4093e4b14e60b5ff7b90f9490c35a0f Mon Sep 17 00:00:00 2001 From: Phuc Van Phan Date: Thu, 28 Sep 2023 15:14:17 +0700 Subject: [PATCH] docs: change assert to raise and some small docs (#26232) * docs: change assert to raise and some small docs * docs: add rule and some document * fix: fix bug * fix: fix bug * chorse: revert logging * chorse: revert --- examples/pytorch/language-modeling/run_clm_no_trainer.py | 9 ++++++--- examples/pytorch/language-modeling/run_mlm_no_trainer.py | 5 +++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py index 3de3c7219c..b02a89e6df 100755 --- a/examples/pytorch/language-modeling/run_clm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py @@ -246,13 +246,16 @@ def parse_args(): else: if args.train_file is not None: extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, json or txt file.") if args.validation_file is not None: extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, json or txt file.") if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + if args.output_dir is None: + raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.") return args diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index 3d035fded5..749810cd31 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -261,7 +261,8 @@ def parse_args(): raise ValueError("`validation_file` should be a csv, json or txt file.") if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + if args.output_dir is None: + raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.") return args @@ -694,7 +695,7 @@ def main(): except OverflowError: perplexity = float("inf") - logger.info(f"epoch {epoch}: perplexity: {perplexity}") + logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") if args.with_tracking: accelerator.log(