Add push_to_hub to no_trainer examples (#13659)

* Add push_to_hub to no_trainer examples

* Quality

* Document integration

* Roll out to other examples
This commit is contained in:
Sylvain Gugger
2021-09-21 13:13:30 -04:00
committed by GitHub
parent a722c301bf
commit b7d264be0d
10 changed files with 319 additions and 27 deletions

View File

@@ -18,6 +18,7 @@ import logging
import math
import os
import random
from pathlib import Path
import datasets
from datasets import load_dataset, load_metric
@@ -26,6 +27,7 @@ from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from huggingface_hub import Repository
from transformers import (
AdamW,
AutoConfig,
@@ -38,6 +40,7 @@ from transformers import (
get_scheduler,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils.versions import require_version
@@ -142,6 +145,11 @@ def parse_args():
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument(
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
args = parser.parse_args()
# Sanity checks
@@ -155,8 +163,8 @@ def parse_args():
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
return args
@@ -188,6 +196,18 @@ def main():
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
@@ -426,10 +446,22 @@ def main():
eval_metric = metric.compute()
logger.info(f"epoch {epoch}: {eval_metric}")
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir)
repo.push_to_hub(commit_message=f"Training in progress epoch {epoch}", blocking=False)
if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training")
if args.task_name == "mnli":
# Final evaluation on mismatched validation set