🐛 Properly raise RepoNotFoundError when not authenticated (#17651)
* Raise RepoNotFoundError in case of 401 * Include changes from revert-17646-skip_repo_not_found * Add a comment * 💄 Code quality * 💚 Update `get_from_cache` test * 💚 Code quality & skip failing test
This commit is contained in:
@@ -38,6 +38,7 @@ import requests
|
||||
from filelock import FileLock
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
|
||||
from requests.exceptions import HTTPError
|
||||
from requests.models import Response
|
||||
from transformers.utils.logging import tqdm
|
||||
|
||||
from . import __version__, logging
|
||||
@@ -398,20 +399,27 @@ class RevisionNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
||||
|
||||
|
||||
def _raise_for_status(request):
|
||||
def _raise_for_status(response: Response):
|
||||
"""
|
||||
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
|
||||
"""
|
||||
if "X-Error-Code" in request.headers:
|
||||
error_code = request.headers["X-Error-Code"]
|
||||
if "X-Error-Code" in response.headers:
|
||||
error_code = response.headers["X-Error-Code"]
|
||||
if error_code == "RepoNotFound":
|
||||
raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {request.url}")
|
||||
raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {response.url}")
|
||||
elif error_code == "EntryNotFound":
|
||||
raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}")
|
||||
raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {response.url}")
|
||||
elif error_code == "RevisionNotFound":
|
||||
raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {request.url}")
|
||||
raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {response.url}")
|
||||
|
||||
request.raise_for_status()
|
||||
if response.status_code == 401:
|
||||
# The repo was not found and the user is not Authenticated
|
||||
raise RepositoryNotFoundError(
|
||||
f"401 Client Error: Repository not found for url: {response.url}. "
|
||||
"If the repo is private, make sure you are authenticated."
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
|
||||
|
||||
@@ -88,7 +88,6 @@ class AutoConfigTest(unittest.TestCase):
|
||||
if "custom" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["custom"]
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
|
||||
@@ -76,7 +76,6 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
|
||||
@@ -328,7 +328,6 @@ class AutoModelTest(unittest.TestCase):
|
||||
if CustomConfig in mapping._extra_content:
|
||||
del mapping._extra_content[CustomConfig]
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
|
||||
@@ -77,7 +77,6 @@ class FlaxAutoModelTest(unittest.TestCase):
|
||||
|
||||
eval(**tokens).block_until_ready()
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
|
||||
@@ -265,7 +265,6 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
if NewModelConfig in mapping._extra_content:
|
||||
del mapping._extra_content[NewModelConfig]
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
|
||||
@@ -142,7 +142,6 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(tokenizer.model_max_length, 512)
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
@require_tokenizers
|
||||
def test_tokenizer_identifier_non_existent(self):
|
||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||
@@ -330,7 +329,6 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
else:
|
||||
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
|
||||
@@ -99,12 +99,19 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
|
||||
def test_model_not_found(self):
|
||||
# Invalid model file.
|
||||
def test_model_not_found_not_authenticated(self):
|
||||
# Invalid model id.
|
||||
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
||||
with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
@unittest.skip("No authentication when testing against prod")
|
||||
def test_model_not_found_authenticated(self):
|
||||
# Invalid model id.
|
||||
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
||||
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
_ = get_from_cache(url, use_auth_token="hf_sometoken")
|
||||
# ^ TODO - if we decide to unskip this: use a real / functional token
|
||||
|
||||
def test_revision_not_found(self):
|
||||
# Valid file but missing revision
|
||||
@@ -142,9 +149,8 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt"))
|
||||
|
||||
# The function raises if the repository does not exist.
|
||||
# Uncomment when bug is fixed.
|
||||
# with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
|
||||
# get_file_from_repo("bert-base-case", "config.json")
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
|
||||
get_file_from_repo("bert-base-case", "config.json")
|
||||
|
||||
# The function raises if the revision does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
|
||||
|
||||
Reference in New Issue
Block a user