Auto-resume training from checkpoint (#9776)

* Auto-resume training from checkpoint

* Update examples/text-classification/run_glue.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Roll out to other examples

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2021-01-25 12:03:51 -05:00
committed by GitHub
parent 0f443436fb
commit caf4abf768
12 changed files with 255 additions and 168 deletions

View File

@@ -42,7 +42,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -160,16 +160,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -356,11 +360,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
model_path = None
train_result = trainer.train(model_path=model_path) train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload

View File

@@ -42,7 +42,7 @@ from transformers import (
TrainingArguments, TrainingArguments,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -171,16 +171,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -397,11 +401,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
model_path = None
train_result = trainer.train(model_path=model_path) train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload

View File

@@ -44,7 +44,7 @@ from transformers import (
TrainingArguments, TrainingArguments,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -184,16 +184,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -349,11 +353,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
model_path = None
train_result = trainer.train(model_path=model_path) train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload

View File

@@ -38,7 +38,7 @@ from transformers import (
XLNetLMHeadModel, XLNetLMHeadModel,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -168,16 +168,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -378,11 +382,12 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
model_path = None
train_result = trainer.train(model_path=model_path) train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload

View File

@@ -39,7 +39,7 @@ from transformers import (
set_seed, set_seed,
) )
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -194,16 +194,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -334,9 +338,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

View File

@@ -39,7 +39,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
from utils_qa import postprocess_qa_predictions from utils_qa import postprocess_qa_predictions
@@ -169,16 +169,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -453,9 +457,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

View File

@@ -38,7 +38,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
from utils_qa import postprocess_qa_predictions_with_beam_search from utils_qa import postprocess_qa_predictions_with_beam_search
@@ -168,16 +168,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -492,9 +496,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

View File

@@ -40,7 +40,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -225,16 +225,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -481,9 +485,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

View File

@@ -38,7 +38,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
task_to_keys = { task_to_keys = {
@@ -160,16 +160,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. " f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -385,9 +389,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
metrics = train_result.metrics metrics = train_result.metrics
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload

View File

@@ -39,7 +39,7 @@ from transformers import (
TrainingArguments, TrainingArguments,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -154,16 +154,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -374,9 +378,13 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")

View File

@@ -17,7 +17,9 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
""" """
import copy import copy
import os
import random import random
import re
import time import time
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
@@ -75,6 +77,15 @@ class TrainOutput(NamedTuple):
PREFIX_CHECKPOINT_DIR = "checkpoint"
# Matches entry names of the form "checkpoint-<step>". The group captures the WHOLE step number:
# a repeated one-digit group such as `(\d)+` would only retain the last digit matched.
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")


def get_last_checkpoint(folder):
    """
    Return the path of the checkpoint directory with the highest step number inside :obj:`folder`.

    Args:
        folder (:obj:`str`): Directory whose direct children are scanned for ``checkpoint-<step>`` sub-directories.

    Returns:
        :obj:`str` or :obj:`None`: ``os.path.join(folder, "checkpoint-<step>")`` for the largest ``<step>``, or
        :obj:`None` when :obj:`folder` contains no matching sub-directory.

    Raises:
        OSError: Propagated from :obj:`os.listdir` when :obj:`folder` cannot be listed.
    """
    checkpoints = []
    for name in os.listdir(folder):
        match = _re_checkpoint.search(name)
        # `name` is relative to `folder`, so it has to be joined before touching the filesystem; testing the
        # bare name would resolve it against the current working directory instead.
        if match is not None and os.path.isdir(os.path.join(folder, name)):
            checkpoints.append((int(match.group(1)), name))
    if not checkpoints:
        return None
    # Compare on the integer step (not lexicographically) and hand back a path the caller can open directly.
    return os.path.join(folder, max(checkpoints)[1])
class EvaluationStrategy(ExplicitEnum): class EvaluationStrategy(ExplicitEnum):

View File

@@ -39,7 +39,7 @@ from transformers import (
default_data_collator, default_data_collator,
set_seed, set_seed,
) )
from transformers.trainer_utils import is_main_process from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -168,16 +168,20 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( # Detecting last checkpoint.
os.path.exists(training_args.output_dir) last_checkpoint = None
and os.listdir(training_args.output_dir) if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
and training_args.do_train last_checkpoint = get_last_checkpoint(training_args.output_dir)
and not training_args.overwrite_output_dir if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
):
raise ValueError( raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty." f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
@@ -334,17 +338,21 @@ def main():
# Training # Training
if training_args.do_train: if training_args.do_train:
{%- if cookiecutter.can_train_from_scratch == "False" %} {%- if cookiecutter.can_train_from_scratch == "False" %}
train_result = trainer.train( if last_checkpoint is not None:
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None model_path = last_checkpoint
) elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
{%- elif cookiecutter.can_train_from_scratch == "True" %} {%- elif cookiecutter.can_train_from_scratch == "True" %}
model_path = ( if last_checkpoint is not None:
model_args.model_name_or_path model_path = last_checkpoint
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)) elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
else None model_path = model_args.model_name_or_path
) else:
train_result = trainer.train(model_path=model_path) model_path = None
{% endif %} {% endif %}
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt") output_train_file = os.path.join(training_args.output_dir, "train_results.txt")