Auto-resume training from checkpoint (#9776)

* Auto-resume training from checkpoint

* Update examples/text-classification/run_glue.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Roll out to other examples

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2021-01-25 12:03:51 -05:00
committed by GitHub
parent 0f443436fb
commit caf4abf768
12 changed files with 255 additions and 168 deletions

View File

@@ -39,7 +39,7 @@ from transformers import (
default_data_collator,
set_seed,
)
from transformers.trainer_utils import is_main_process
from transformers.trainer_utils import get_last_checkpoint, is_main_process
logger = logging.getLogger(__name__)
@@ -168,16 +168,20 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging
logging.basicConfig(
@@ -334,17 +338,21 @@ def main():
# Training
if training_args.do_train:
{%- if cookiecutter.can_train_from_scratch == "False" %}
train_result = trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
if last_checkpoint is not None:
model_path = last_checkpoint
elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
{%- elif cookiecutter.can_train_from_scratch == "True" %}
model_path = (
model_args.model_name_or_path
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
train_result = trainer.train(model_path=model_path)
if last_checkpoint is not None:
model_path = last_checkpoint
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
else:
model_path = None
{% endif %}
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")