Add push_to_hub to no_trainer examples (#13659)
* Add push_to_hub to no_trainer examples * Quality * Document integration * Roll out to other examples
This commit is contained in:
@@ -18,6 +18,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset, load_metric
|
||||
@@ -26,6 +27,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
@@ -38,6 +40,7 @@ from transformers import (
|
||||
get_scheduler,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import get_full_repo_name
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
@@ -142,6 +145,11 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@@ -155,8 +163,8 @@ def parse_args():
|
||||
extension = args.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if args.push_to_hub:
|
||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||
|
||||
return args
|
||||
|
||||
@@ -188,6 +196,18 @@ def main():
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
||||
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
@@ -426,10 +446,22 @@ def main():
|
||||
eval_metric = metric.compute()
|
||||
logger.info(f"epoch {epoch}: {eval_metric}")
|
||||
|
||||
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
repo.push_to_hub(commit_message=f"Training in progress epoch {epoch}", blocking=False)
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training")
|
||||
|
||||
if args.task_name == "mnli":
|
||||
# Final evaluation on mismatched validation set
|
||||
|
||||
Reference in New Issue
Block a user