Clean up hub (#18497)

* Clean up utils.hub

* Remove imports

* More fixes

* Last fix
This commit is contained in:
Sylvain Gugger
2022-08-08 08:48:10 -04:00
committed by GitHub
parent a4562552eb
commit 377cdded7a
14 changed files with 67 additions and 708 deletions

View File

@@ -26,20 +26,13 @@ import transformers
from transformers import * # noqa F406
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
from transformers.utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
ContextManagers,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
filename_to_url,
find_labels,
get_file_from_repo,
get_from_cache,
has_file,
hf_bucket_url,
is_flax_available,
is_tf_available,
is_torch_available,
@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase):
class GetFromCacheTests(unittest.TestCase):
def test_bogus_url(self):
# This lets us simulate no connection
# as the error raised is the same
# `ConnectionError`
url = "https://bogus"
with self.assertRaisesRegex(ValueError, "Connection error"):
_ = get_from_cache(url)
def test_file_not_found(self):
# Valid revision (None) but missing file.
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
def test_model_not_found_not_authenticated(self):
# Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
_ = get_from_cache(url)
@unittest.skip("No authentication when testing against prod")
def test_model_not_found_authenticated(self):
# Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
_ = get_from_cache(url, use_auth_token="hf_sometoken")
# ^ TODO - if we decide to unskip this: use a real / functional token
def test_revision_not_found(self):
# Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
def test_standard_object(self):
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
def test_standard_object_rev(self):
# Same object, but different revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
def test_lfs_object(self):
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
def test_has_file(self):
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))