Add test when downloading from gated repo (#25039)
This commit is contained in:
@@ -36,6 +36,9 @@ RANDOM_BERT = "hf-internal-testing/tiny-random-bert"
|
|||||||
CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert")
|
CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert")
|
||||||
FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6"
|
FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6"
|
||||||
|
|
||||||
|
GATED_REPO = "hf-internal-testing/dummy-gated-model"
|
||||||
|
README_FILE = "README.md"
|
||||||
|
|
||||||
|
|
||||||
class GetFromCacheTests(unittest.TestCase):
|
class GetFromCacheTests(unittest.TestCase):
|
||||||
def test_cached_file(self):
|
def test_cached_file(self):
|
||||||
@@ -124,3 +127,13 @@ class GetFromCacheTests(unittest.TestCase):
|
|||||||
self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
|
self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
|
||||||
|
|
||||||
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
||||||
|
|
||||||
|
def test_get_file_gated_repo(self):
|
||||||
|
"""Test download file from a gated repo fails with correct message when not authenticated."""
|
||||||
|
with self.assertRaisesRegex(EnvironmentError, "You are trying to access a gated repo."):
|
||||||
|
cached_file(GATED_REPO, README_FILE, use_auth_token=False)
|
||||||
|
|
||||||
|
def test_has_file_gated_repo(self):
|
||||||
|
"""Test check file existence from a gated repo fails with correct message when not authenticated."""
|
||||||
|
with self.assertRaisesRegex(EnvironmentError, "is a gated repository"):
|
||||||
|
has_file(GATED_REPO, README_FILE, use_auth_token=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user