Trainer push to hub (#11328)

* Initial support for upload to hub

* push -> upload

* Fixes + examples

* Fix torchhub test

* Torchhub test I hate you

* push_model_to_hub -> push_to_hub

* Apply mixin to other pretrained models

* Remove ABC inheritance

* Add tests

* Typo

* Run tests

* Install git-lfs

* Change approach

* Add push_to_hub to all

* Staging test suite

* Typo

* Maybe like this?

* More deps

* Cache

* Adapt name

* Quality

* MOAR tests

* Put it in testing_utils

* Docs + torchhub last hope

* Styling

* Wrong method

* Typos

* Update src/transformers/file_utils.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address review comments

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-04-23 09:17:37 -04:00
committed by GitHub
parent 7bc86bea68
commit bf2e0cf70b
31 changed files with 766 additions and 31 deletions

View File

@@ -20,11 +20,15 @@ import pickle
import re
import shutil
import tempfile
import unittest
from collections import OrderedDict
from itertools import takewhile
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import (
BertTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
@@ -32,8 +36,12 @@ from transformers import (
is_torch_available,
)
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
get_tests_dir,
is_pt_tf_cross_test,
is_staging_test,
require_tf,
require_tokenizers,
require_torch,
@@ -2863,3 +2871,53 @@ class TokenizerTesterMixin:
)
for key in python_output:
self.assertEqual(python_output[key], rust_output[key])
@is_staging_test
class TokenzierPushToHubTester(unittest.TestCase):
    """Round-trip checks for pushing a tokenizer to the hub through ``save_pretrained``.

    Each test writes a small BERT vocabulary to a temporary directory, builds a
    ``BertTokenizer`` from it, saves it with ``push_to_hub=True``, reloads the
    tokenizer by its repository id and compares both vocabularies.
    """

    # NOTE(review): the class name is spelled "Tokenzier"; it is kept unchanged
    # here so the identifier stays the same for anything selecting it by name.
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]

    @classmethod
    def setUpClass(cls):
        """Create an ``HfApi`` client on ``ENDPOINT_STAGING`` and store the login token."""
        cls._api = HfApi(endpoint=ENDPOINT_STAGING)
        cls._token = cls._api.login(username=USER, password=PASS)

    @classmethod
    def tearDownClass(cls):
        """Delete the repositories used by the tests, ignoring ``HTTPError``."""
        repositories = (
            {"name": "test-model"},
            {"name": "test-model-org", "organization": "valid_org"},
        )
        for repository in repositories:
            try:
                cls._api.delete_repo(token=cls._token, **repository)
            except HTTPError:
                # Best-effort cleanup: presumably the repository may be missing
                # when the matching test did not get as far as pushing.
                continue

    def _make_tokenizer(self, directory):
        """Write ``vocab_tokens`` to ``vocab.txt`` in ``directory`` and load a ``BertTokenizer`` from it."""
        vocab_path = os.path.join(directory, "vocab.txt")
        vocab_text = "".join(f"{token}\n" for token in self.vocab_tokens)
        with open(vocab_path, "w", encoding="utf-8") as handle:
            handle.write(vocab_text)
        # The tokenizer is built only after the vocabulary file has been closed.
        return BertTokenizer(vocab_path)

    def test_push_to_hub(self):
        """Push under the ``USER`` namespace and check the reloaded vocabulary matches."""
        with tempfile.TemporaryDirectory() as workdir:
            original = self._make_tokenizer(workdir)
            original.save_pretrained(workdir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)

            reloaded = BertTokenizer.from_pretrained(f"{USER}/test-model")
            self.assertDictEqual(reloaded.vocab, original.vocab)

    def test_push_to_hub_in_organization(self):
        """Push under the ``valid_org`` organization and check the reloaded vocabulary matches."""
        with tempfile.TemporaryDirectory() as workdir:
            original = self._make_tokenizer(workdir)
            original.save_pretrained(
                workdir,
                push_to_hub=True,
                repo_name="test-model-org",
                use_auth_token=self._token,
                organization="valid_org",
            )

            reloaded = BertTokenizer.from_pretrained("valid_org/test-model-org")
            self.assertDictEqual(reloaded.vocab, original.vocab)