Fix .push_to_hub and cleanup get_full_repo_name usage (#25120)

* Fix .push_to_hub and cleanup get_full_repo_name usage

* Do not rely on Python bool conversion magic

* request changes
This commit is contained in:
Lucain
2023-07-28 11:40:08 +02:00
committed by GitHub
parent 400e76ef11
commit 6232c380f2
31 changed files with 266 additions and 241 deletions

View File

@@ -54,7 +54,7 @@ from transformers import (
is_tensorboard_available,
set_seed,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__)
@@ -293,14 +293,14 @@ def main():
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Retrieve or infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing