Model versioning (#8324)

* fix typo

* rm use_cdn & references, and implement new hf_bucket_url

* I'm pretty sure we don't need to `read` this file

* same here

* [BIG] file_utils.networking: do not gobble up errors anymore

* Fix CI 😇

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Tiny doc tweak

* Add doc + pass kwarg everywhere

* Add more tests and explain

cc @sshleifer let me know if better

Co-Authored-By: Sam Shleifer <sshleifer@gmail.com>

* Also implement revision in pipelines

In the case where we're passing a task name or a string model identifier

* Fix CI 😇

* Fix CI

* [hf_api] new methods + command line implem

* make style

* Final endpoints post-migration

* Fix post-migration

* Py3.6 compat

cc @stefan-it

Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Julien Chaumond
2020-11-10 13:11:02 +01:00
committed by GitHub
parent 4185b115d4
commit 70f622fab4
23 changed files with 472 additions and 210 deletions

63
tests/test_file_utils.py Normal file
View File

@@ -0,0 +1,63 @@
import unittest
import requests
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
# An actual model hosted on huggingface.co
REVISION_ID_DEFAULT = "main"
# Default branch name
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
# One particular commit (not the top of `main`)
REVISION_ID_INVALID = "aaaaaaa"
# This commit does not exist, so we should 404.
PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
# Sha-1 of config.json on the top of `main`, for checking purposes
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
class GetFromCacheTests(unittest.TestCase):
def test_bogus_url(self):
# This lets us simulate no connection
# as the error raised is the same
# `ConnectionError`
url = "https://bogus"
with self.assertRaisesRegex(ValueError, "Connection error"):
_ = get_from_cache(url)
def test_file_not_found(self):
# Valid revision (None) but missing file.
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
_ = get_from_cache(url)
def test_revision_not_found(self):
# Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
_ = get_from_cache(url)
def test_standard_object(self):
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
def test_standard_object_rev(self):
# Same object, but different revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
def test_lfs_object(self):
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))

View File

@@ -20,7 +20,7 @@ import unittest
import requests
from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, RepoObj, S3Obj
USER = "__DUMMY_TRANSFORMERS_USER__"
@@ -35,6 +35,7 @@ FILES = [
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
),
]
REPO_NAME = "my-model-{}".format(int(time.time()))
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
@@ -78,15 +79,6 @@ class HfApiEndpointsTest(HfApiCommonTest):
urls = self._api.presign(token=self._token, filename="nested/valid_org.txt", organization="valid_org")
self.assertIsInstance(urls, PresignedUrl)
def test_presign_invalid(self):
try:
_ = self._api.presign(token=self._token, filename="non_nested.json")
except HTTPError as e:
self.assertIsNotNone(e.response.text)
self.assertTrue("Filename invalid" in e.response.text)
else:
self.fail("Expected an exception")
def test_presign(self):
for FILE_KEY, FILE_PATH in FILES:
urls = self._api.presign(token=self._token, filename=FILE_KEY)
@@ -109,6 +101,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
o = objs[-1]
self.assertIsInstance(o, S3Obj)
def test_list_repos_objs(self):
objs = self._api.list_repos_objs(token=self._token)
self.assertIsInstance(objs, list)
if len(objs) > 0:
o = objs[-1]
self.assertIsInstance(o, RepoObj)
def test_create_and_delete_repo(self):
self._api.create_repo(token=self._token, name=REPO_NAME)
self._api.delete_repo(token=self._token, name=REPO_NAME)
class HfApiPublicTest(unittest.TestCase):
def test_staging_model_list(self):

View File

@@ -323,7 +323,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
def test_custom_load_tf_weights(self):
model, output_loading_info = TFBertForTokenClassification.from_pretrained(
"jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
"jplu/tiny-tf-bert-random", output_loading_info=True
)
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
for layer in output_loading_info["missing_keys"]: