Model versioning (#8324)
* fix typo * rm use_cdn & references, and implement new hf_bucket_url * I'm pretty sure we don't need to `read` this file * same here * [BIG] file_utils.networking: do not gobble up errors anymore * Fix CI 😇 * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Tiny doc tweak * Add doc + pass kwarg everywhere * Add more tests and explain cc @sshleifer let me know if better Co-Authored-By: Sam Shleifer <sshleifer@gmail.com> * Also implement revision in pipelines In the case where we're passing a task name or a string model identifier * Fix CI 😇 * Fix CI * [hf_api] new methods + command line implem * make style * Final endpoints post-migration * Fix post-migration * Py3.6 compat cc @stefan-it Thank you @stas00 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
63
tests/test_file_utils.py
Normal file
63
tests/test_file_utils.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
|
||||
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
|
||||
|
||||
|
||||
MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
|
||||
# An actual model hosted on huggingface.co
|
||||
|
||||
REVISION_ID_DEFAULT = "main"
|
||||
# Default branch name
|
||||
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
|
||||
# One particular commit (not the top of `main`)
|
||||
REVISION_ID_INVALID = "aaaaaaa"
|
||||
# This commit does not exist, so we should 404.
|
||||
|
||||
PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
|
||||
# Sha-1 of config.json on the top of `main`, for checking purposes
|
||||
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
|
||||
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
|
||||
|
||||
|
||||
class GetFromCacheTests(unittest.TestCase):
|
||||
def test_bogus_url(self):
|
||||
# This lets us simulate no connection
|
||||
# as the error raised is the same
|
||||
# `ConnectionError`
|
||||
url = "https://bogus"
|
||||
with self.assertRaisesRegex(ValueError, "Connection error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_file_not_found(self):
|
||||
# Valid revision (None) but missing file.
|
||||
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
|
||||
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_revision_not_found(self):
|
||||
# Valid file but missing revision
|
||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
|
||||
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_standard_object(self):
|
||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
|
||||
|
||||
def test_standard_object_rev(self):
|
||||
# Same object, but different revision
|
||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
|
||||
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
|
||||
|
||||
def test_lfs_object(self):
|
||||
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
|
||||
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, RepoObj, S3Obj
|
||||
|
||||
|
||||
USER = "__DUMMY_TRANSFORMERS_USER__"
|
||||
@@ -35,6 +35,7 @@ FILES = [
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
|
||||
),
|
||||
]
|
||||
REPO_NAME = "my-model-{}".format(int(time.time()))
|
||||
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
|
||||
|
||||
|
||||
@@ -78,15 +79,6 @@ class HfApiEndpointsTest(HfApiCommonTest):
|
||||
urls = self._api.presign(token=self._token, filename="nested/valid_org.txt", organization="valid_org")
|
||||
self.assertIsInstance(urls, PresignedUrl)
|
||||
|
||||
def test_presign_invalid(self):
|
||||
try:
|
||||
_ = self._api.presign(token=self._token, filename="non_nested.json")
|
||||
except HTTPError as e:
|
||||
self.assertIsNotNone(e.response.text)
|
||||
self.assertTrue("Filename invalid" in e.response.text)
|
||||
else:
|
||||
self.fail("Expected an exception")
|
||||
|
||||
def test_presign(self):
|
||||
for FILE_KEY, FILE_PATH in FILES:
|
||||
urls = self._api.presign(token=self._token, filename=FILE_KEY)
|
||||
@@ -109,6 +101,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
|
||||
o = objs[-1]
|
||||
self.assertIsInstance(o, S3Obj)
|
||||
|
||||
def test_list_repos_objs(self):
|
||||
objs = self._api.list_repos_objs(token=self._token)
|
||||
self.assertIsInstance(objs, list)
|
||||
if len(objs) > 0:
|
||||
o = objs[-1]
|
||||
self.assertIsInstance(o, RepoObj)
|
||||
|
||||
def test_create_and_delete_repo(self):
|
||||
self._api.create_repo(token=self._token, name=REPO_NAME)
|
||||
self._api.delete_repo(token=self._token, name=REPO_NAME)
|
||||
|
||||
|
||||
class HfApiPublicTest(unittest.TestCase):
|
||||
def test_staging_model_list(self):
|
||||
|
||||
@@ -323,7 +323,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def test_custom_load_tf_weights(self):
|
||||
model, output_loading_info = TFBertForTokenClassification.from_pretrained(
|
||||
"jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
|
||||
"jplu/tiny-tf-bert-random", output_loading_info=True
|
||||
)
|
||||
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
|
||||
for layer in output_loading_info["missing_keys"]:
|
||||
|
||||
Reference in New Issue
Block a user