Rewrite push_to_hub to use upload_files (#18366)

* Rewrite push_to_hub to use upload_files

* Adapt the doc a bit

* Address review comments and clean doc
This commit is contained in:
Sylvain Gugger
2022-08-01 12:07:30 -04:00
committed by GitHub
parent 3909d7f139
commit 01db72abd4
18 changed files with 555 additions and 527 deletions

View File

@@ -2077,15 +2077,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
filename_prefix: (`str`, *optional*):
A prefix to add to the names of the files saved by the tokenizer.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
<Tip warning={true}>
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
folder. Pass along `temp_dir=True` to use a temporary directory instead.
</Tip>
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
Returns:
A tuple of `str`: The files saved.
@@ -2094,11 +2090,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
special_tokens_map_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
@@ -2167,8 +2165,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
)
if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Tokenizer pushed to the hub in this commit: {url}")
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
)
return save_files