Clean push to hub API (#12187)

* Clean push to hub API

* Create working dir if it does not exist

* Different tweak

* New API + all models + test Flax

* Adds the Trainer clean up

* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* (nit) output types

* No need to set clone_from when folder exists

* Update src/transformers/trainer.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Add generated_from_trainer tag

* Update to new version

* Fixes

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger
2021-06-23 10:11:19 -04:00
committed by GitHub
parent 625f512d5e
commit 53c60babe4
17 changed files with 368 additions and 159 deletions

View File

@@ -1884,6 +1884,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
value error is raised.
filename_prefix: (:obj:`str`, `optional`):
A prefix to add to the names of the files saved by the tokenizer.
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
.. warning::
Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
:obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
instead.
Returns:
A tuple of :obj:`str`: The files saved.
@@ -1891,6 +1900,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
special_tokens_map_file = os.path.join(
@@ -1949,9 +1963,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
)
if push_to_hub:
# Annoyingly, the return contains files that don't exist.
existing_files = [f for f in save_files if os.path.isfile(f)]
url = self._push_to_hub(save_files=existing_files, **kwargs)
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Tokenizer pushed to the hub in this commit: {url}")
return save_files