From caf4abf768cfbefa9f004da1a918c86c92881f8a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 25 Jan 2021 12:03:51 -0500 Subject: [PATCH] Auto-resume training from checkpoint (#9776) * Auto-resume training from checkpoint * Update examples/text-classification/run_glue.py Co-authored-by: Lysandre Debut * Roll out to other examples Co-authored-by: Lysandre Debut --- examples/language-modeling/run_clm.py | 37 +++++++------- examples/language-modeling/run_mlm.py | 37 +++++++------- examples/language-modeling/run_mlm_wwm.py | 37 +++++++------- examples/language-modeling/run_plm.py | 37 +++++++------- examples/multiple-choice/run_swag.py | 36 ++++++++------ examples/question-answering/run_qa.py | 36 ++++++++------ .../question-answering/run_qa_beam_search.py | 36 ++++++++------ examples/seq2seq/run_seq2seq.py | 36 ++++++++------ examples/text-classification/run_glue.py | 36 ++++++++------ examples/token-classification/run_ner.py | 36 ++++++++------ src/transformers/trainer_utils.py | 11 +++++ .../run_{{cookiecutter.example_shortcut}}.py | 48 +++++++++++-------- 12 files changed, 255 insertions(+), 168 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 2cb1e45414..57f8f3cd2d 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -42,7 +42,7 @@ from transformers import ( default_data_collator, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -160,16 +160,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -356,11 +360,12 @@ def main(): # Training if training_args.do_train: - model_path = ( - model_args.model_name_or_path - if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) - else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 223b8508fb..bf72a09a30 100644 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -42,7 +42,7 @@ from transformers import ( TrainingArguments, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -171,16 +171,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -397,11 +401,12 @@ def main(): # Training if training_args.do_train: - model_path = ( - model_args.model_name_or_path - if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) - else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload diff --git a/examples/language-modeling/run_mlm_wwm.py b/examples/language-modeling/run_mlm_wwm.py index 5912cdd28b..ab80b25c3a 100644 --- a/examples/language-modeling/run_mlm_wwm.py +++ b/examples/language-modeling/run_mlm_wwm.py @@ -44,7 +44,7 @@ from transformers import ( TrainingArguments, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -184,16 +184,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -349,11 +353,12 @@ def main(): # Training if training_args.do_train: - model_path = ( - model_args.model_name_or_path - if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) - else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload diff --git a/examples/language-modeling/run_plm.py b/examples/language-modeling/run_plm.py index 7a11a3f2c9..589b53ded2 100644 --- a/examples/language-modeling/run_plm.py +++ b/examples/language-modeling/run_plm.py @@ -38,7 +38,7 @@ from transformers import ( XLNetLMHeadModel, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -168,16 +168,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -378,11 +382,12 @@ def main(): # Training if training_args.do_train: - model_path = ( - model_args.model_name_or_path - if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) - else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload diff --git a/examples/multiple-choice/run_swag.py b/examples/multiple-choice/run_swag.py index bd9ce634e1..4ea2296ba8 100644 --- a/examples/multiple-choice/run_swag.py +++ b/examples/multiple-choice/run_swag.py @@ -39,7 +39,7 @@ from transformers import ( set_seed, ) from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -194,16 +194,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -334,9 +338,13 @@ def main(): # Training if training_args.do_train: - train_result = trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None + train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload output_train_file = os.path.join(training_args.output_dir, "train_results.txt") diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index dc3cce05b9..18a82d66fd 100644 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -39,7 +39,7 @@ from transformers import ( default_data_collator, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process from utils_qa import postprocess_qa_predictions @@ -169,16 +169,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -453,9 +457,13 @@ def main(): # Training if training_args.do_train: - train_result = trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None + train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload output_train_file = os.path.join(training_args.output_dir, "train_results.txt") diff --git a/examples/question-answering/run_qa_beam_search.py b/examples/question-answering/run_qa_beam_search.py index 6d343ce766..7d567620a9 100644 --- a/examples/question-answering/run_qa_beam_search.py +++ b/examples/question-answering/run_qa_beam_search.py @@ -38,7 +38,7 @@ from transformers import ( default_data_collator, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process from utils_qa import postprocess_qa_predictions_with_beam_search @@ -168,16 +168,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -492,9 +496,13 @@ def main(): # Training if training_args.do_train: - train_result = trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None + train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload output_train_file = os.path.join(training_args.output_dir, "train_results.txt") diff --git a/examples/seq2seq/run_seq2seq.py b/examples/seq2seq/run_seq2seq.py index b5a04fdb73..f92aa1f91f 100644 --- a/examples/seq2seq/run_seq2seq.py +++ b/examples/seq2seq/run_seq2seq.py @@ -40,7 +40,7 @@ from transformers import ( default_data_collator, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -225,16 +225,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -481,9 +485,13 @@ def main(): # Training if training_args.do_train: - train_result = trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None + train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload output_train_file = os.path.join(training_args.output_dir, "train_results.txt") diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 60d33786b4..963d41cec0 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -38,7 +38,7 @@ from transformers import ( default_data_collator, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process task_to_keys = { @@ -160,16 +160,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -385,9 +389,13 @@ def main(): # Training if training_args.do_train: - train_result = trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None + train_result = trainer.train(model_path=model_path) metrics = train_result.metrics trainer.save_model() # Saves the tokenizer too for easy upload diff --git a/examples/token-classification/run_ner.py b/examples/token-classification/run_ner.py index 807d2ee7c4..8253e99852 100644 --- a/examples/token-classification/run_ner.py +++ b/examples/token-classification/run_ner.py @@ -39,7 +39,7 @@ from transformers import ( TrainingArguments, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -154,16 +154,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -374,9 +378,13 @@ def main(): # Training if training_args.do_train: - train_result = trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None + train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload output_train_file = os.path.join(training_args.output_dir, "train_results.txt") diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index ff41391437..e032ede0ea 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -17,7 +17,9 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc """ import copy +import os import random +import re import time from typing import Any, Dict, NamedTuple, Optional, Tuple, Union @@ -75,6 +77,15 @@ class TrainOutput(NamedTuple): PREFIX_CHECKPOINT_DIR = "checkpoint" +_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d)+$") + + +def get_last_checkpoint(folder): + content = os.listdir(folder) + checkpoints = [path for path in content if _re_checkpoint.search(path) is not None and os.path.isdir(path)] + if len(checkpoints) == 0: + return + return max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])) class EvaluationStrategy(ExplicitEnum): diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py index 6983b789b0..107e7fe45c 100644 --- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py +++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py @@ -39,7 +39,7 @@ from transformers import ( default_data_collator, set_seed, ) -from transformers.trainer_utils import is_main_process +from transformers.trainer_utils import get_last_checkpoint, is_main_process logger = logging.getLogger(__name__) @@ -168,16 +168,20 @@ def main(): else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() - if ( - os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty." - "Use --overwrite_output_dir to overcome." - ) + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) # Setup logging logging.basicConfig( @@ -334,17 +338,21 @@ def main(): # Training if training_args.do_train: {%- if cookiecutter.can_train_from_scratch == "False" %} - train_result = trainer.train( - model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None - ) + if last_checkpoint is not None: + model_path = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None {%- elif cookiecutter.can_train_from_scratch == "True" %} - model_path = ( - model_args.model_name_or_path - if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) - else None - ) - train_result = trainer.train(model_path=model_path) + if last_checkpoint is not None: + model_path = last_checkpoint + elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path): + model_path = model_args.model_name_or_path + else: + model_path = None {% endif %} + train_result = trainer.train(model_path=model_path) trainer.save_model() # Saves the tokenizer too for easy upload output_train_file = os.path.join(training_args.output_dir, "train_results.txt")