Clean up hub (#18497)
* Clean up utils.hub * Remove imports * More fixes * Last fix
This commit is contained in:
@@ -26,20 +26,13 @@ import transformers
|
||||
from transformers import * # noqa F406
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ContextManagers,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
filename_to_url,
|
||||
find_labels,
|
||||
get_file_from_repo,
|
||||
get_from_cache,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase):
|
||||
|
||||
|
||||
class GetFromCacheTests(unittest.TestCase):
|
||||
def test_bogus_url(self):
|
||||
# This lets us simulate no connection
|
||||
# as the error raised is the same
|
||||
# `ConnectionError`
|
||||
url = "https://bogus"
|
||||
with self.assertRaisesRegex(ValueError, "Connection error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_file_not_found(self):
|
||||
# Valid revision (None) but missing file.
|
||||
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
|
||||
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_model_not_found_not_authenticated(self):
|
||||
# Invalid model id.
|
||||
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
||||
with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
@unittest.skip("No authentication when testing against prod")
|
||||
def test_model_not_found_authenticated(self):
|
||||
# Invalid model id.
|
||||
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
||||
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url, use_auth_token="hf_sometoken")
|
||||
# ^ TODO - if we decide to unskip this: use a real / functional token
|
||||
|
||||
def test_revision_not_found(self):
|
||||
# Valid file but missing revision
|
||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
|
||||
with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_standard_object(self):
|
||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
|
||||
|
||||
def test_standard_object_rev(self):
|
||||
# Same object, but different revision
|
||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
|
||||
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
|
||||
|
||||
def test_lfs_object(self):
|
||||
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
|
||||
filepath = get_from_cache(url, force_download=True)
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
||||
|
||||
def test_has_file(self):
|
||||
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
|
||||
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
|
||||
|
||||
Reference in New Issue
Block a user