From a0c62d249303a68f5336e3f9a96ecf9241d7abbe Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 18 Nov 2020 12:15:26 -0500 Subject: [PATCH] Fix training from scratch in new scripts (#8623) --- examples/language-modeling/run_clm.py | 7 +++++-- examples/language-modeling/run_mlm.py | 7 +++++-- examples/language-modeling/run_mlm_wwm.py | 7 +++++-- examples/language-modeling/run_plm.py | 7 +++++-- .../run_{{cookiecutter.example_shortcut}}.py | 9 +++++++++ 5 files changed, 29 insertions(+), 8 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 396631b9ff..2abdecdd1b 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -313,9 +313,12 @@ def main(): # Training if training_args.do_train: - trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + model_path = ( + model_args.model_name_or_path + if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) + else None ) + trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload # Evaluation diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index dfc2614a72..664128eaf9 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -354,9 +354,12 @@ def main(): # Training if training_args.do_train: - trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + model_path = ( + model_args.model_name_or_path + if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) + else None ) + trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload # Evaluation diff --git a/examples/language-modeling/run_mlm_wwm.py b/examples/language-modeling/run_mlm_wwm.py index b2ffcc34ac..e7c6505fc9 100644 --- a/examples/language-modeling/run_mlm_wwm.py +++ b/examples/language-modeling/run_mlm_wwm.py @@ -302,9 +302,12 @@ def main(): # Training if training_args.do_train: - trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + model_path = ( + model_args.model_name_or_path + if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) + else None ) + trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload # Evaluation diff --git a/examples/language-modeling/run_plm.py b/examples/language-modeling/run_plm.py index 65700a415c..0e264115d8 100644 --- a/examples/language-modeling/run_plm.py +++ b/examples/language-modeling/run_plm.py @@ -344,9 +344,12 @@ def main(): # Training if training_args.do_train: - trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + model_path = ( + model_args.model_name_or_path + if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) + else None ) + trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload # Evaluation diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py index a659ed95bb..cefa064cad 100644 --- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py +++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py @@ -307,9 +307,18 @@ def main(): # Training if training_args.do_train: +{%- if cookiecutter.can_train_from_scratch == "False" %} trainer.train( model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None ) +{%- elif cookiecutter.can_train_from_scratch == "True" %} + model_path = ( + model_args.model_name_or_path + if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) + else None + ) + trainer.train(model_path=model_path) +{% endif %} trainer.save_model() # Saves the tokenizer too for easy upload # Evaluation