Add push_to_hub to no_trainer examples (#13659)
* Add push_to_hub to no_trainer examples * Quality * Document integration * Roll out to other examples
This commit is contained in:
@@ -24,6 +24,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
@@ -34,6 +35,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@@ -47,7 +49,7 @@ from transformers import (
|
||||
get_scheduler,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import PaddingStrategy
|
||||
from transformers.file_utils import PaddingStrategy, get_full_repo_name
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -169,9 +171,15 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Activate debug mode and run training only with a subset of data.",
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
args = parser.parse_args()
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||
|
||||
return args
|
||||
|
||||
@@ -260,6 +268,18 @@ def main():
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
@@ -478,10 +498,22 @@ def main():
|
||||
eval_metric = metric.compute()
|
||||
accelerator.print(f"epoch {epoch}: {eval_metric}")
|
||||
|
||||
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
repo.push_to_hub(commit_message=f"Training in progress epoch {epoch}", blocking=False)
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user