diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py index 3de3c7219c..b02a89e6df 100755 --- a/examples/pytorch/language-modeling/run_clm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py @@ -246,13 +246,16 @@ def parse_args(): else: if args.train_file is not None: extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, json or txt file.") if args.validation_file is not None: extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, json or txt file.") if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + if args.output_dir is None: + raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.") return args diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index 3d035fded5..749810cd31 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -261,7 +261,8 @@ def parse_args(): raise ValueError("`validation_file` should be a csv, json or txt file.") if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + if args.output_dir is None: + raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.") return args @@ -694,7 +695,7 @@ def main(): except OverflowError: perplexity = float("inf") - logger.info(f"epoch {epoch}: perplexity: {perplexity}") + logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") if args.with_tracking: accelerator.log(