Add push_to_hub to no_trainer examples (#13659)
* Add push_to_hub to no_trainer examples * Quality * Document integration * Roll out to other examples
This commit is contained in:
@@ -23,6 +23,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
@@ -32,6 +33,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@@ -45,6 +47,7 @@ from transformers import (
|
||||
get_scheduler,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import get_full_repo_name
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
@@ -195,6 +198,11 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Activate debug mode and run training only with a subset of data.",
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
@@ -208,8 +216,8 @@ def parse_args():
|
||||
extension = args.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if args.push_to_hub:
|
||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||
|
||||
return args
|
||||
|
||||
@@ -241,6 +249,18 @@ def main():
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
@@ -552,10 +572,22 @@ def main():
|
||||
eval_metric = compute_metrics()
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
|
||||
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
repo.push_to_hub(commit_message=f"Training in progress epoch {epoch}", blocking=False)
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user