Fix .push_to_hub and cleanup get_full_repo_name usage (#25120)

* Fix .push_to_hub and cleanup get_full_repo_name usage

* Do not rely on Python bool conversion magic

* request changes
This commit is contained in:
Lucain
2023-07-28 11:40:08 +02:00
committed by GitHub
parent 400e76ef11
commit 6232c380f2
31 changed files with 266 additions and 241 deletions

View File

@@ -29,7 +29,7 @@ import datasets
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from datasets import ClassLabel, load_dataset, load_metric
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from luke_utils import DataCollatorForLukeTokenClassification, is_punctuation, padding_tensor
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
@@ -45,7 +45,6 @@ from transformers import (
get_scheduler,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils.versions import require_version
@@ -258,11 +257,14 @@ def main():
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
# Retrieve or infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()