Deprecate model_path in Trainer.train (#9854)

This commit is contained in:
Sylvain Gugger
2021-01-28 08:32:46 -05:00
committed by GitHub
parent 2ee9f9b69e
commit b4e559cfa1
14 changed files with 96 additions and 78 deletions

View File

@@ -341,20 +341,20 @@ def main():
if training_args.do_train:
{%- if cookiecutter.can_train_from_scratch == "False" %}
if last_checkpoint is not None:
model_path = last_checkpoint
checkpoint = last_checkpoint
elif os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
checkpoint = model_args.model_name_or_path
else:
model_path = None
checkpoint = None
{%- elif cookiecutter.can_train_from_scratch == "True" %}
if last_checkpoint is not None:
model_path = last_checkpoint
checkpoint = last_checkpoint
elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
model_path = model_args.model_name_or_path
checkpoint = model_args.model_name_or_path
else:
model_path = None
checkpoint = None
{% endif %}
train_result = trainer.train(model_path=model_path)
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")