Fix .push_to_hub and cleanup get_full_repo_name usage (#25120)

* Fix .push_to_hub and cleanup get_full_repo_name usage

* Do not rely on Python bool conversion magic

* request changes
This commit is contained in:
Lucain
2023-07-28 11:40:08 +02:00
committed by GitHub
parent 400e76ef11
commit 6232c380f2
31 changed files with 266 additions and 241 deletions

View File

@@ -43,7 +43,7 @@ from transformers import (
set_seed,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import send_example_telemetry
logger = get_logger(__name__)
@@ -418,12 +418,14 @@ def main():
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub and not args.preprocessing_only:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
# Retrieve of infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()