From a32f97c37d53cfc4299e83678c4082e3ddb00bf9 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 26 Sep 2022 18:01:00 -0400 Subject: [PATCH] Fix cached_file in offline mode for cached non-existing files (#19206) * Fix cached_file in offline mode for cached non-existing files * Add tests * Test with offline mode --- src/transformers/utils/hub.py | 4 +- tests/utils/test_file_utils.py | 48 +------------- tests/utils/test_hub_utils.py | 110 +++++++++++++++++++++++++++++++++ utils/tests_fetcher.py | 2 +- 4 files changed, 114 insertions(+), 50 deletions(-) create mode 100644 tests/utils/test_hub_utils.py diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index cd4c92e50b..8c149bec64 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -435,7 +435,7 @@ def cached_file( except LocalEntryNotFoundError: # We try to see if we have a cached version (not up to date): resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision) - if resolved_file is not None: + if resolved_file is not None and resolved_file != _CACHED_NO_EXIST: return resolved_file if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors: return None @@ -457,7 +457,7 @@ def cached_file( except HTTPError as err: # First we try to see if we have a cached version (not up to date): resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision) - if resolved_file is not None: + if resolved_file is not None and resolved_file != _CACHED_NO_EXIST: return resolved_file if not _raise_exceptions_for_connection_errors: return None diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index 60676e9f7d..e7963bfa51 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -15,28 +15,14 @@ import contextlib import importlib import io -import json -import tempfile import unittest -from pathlib import Path import transformers # Try to import everything from transformers to ensure every object can be loaded. from transformers import * # noqa F406 from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER -from transformers.utils import ( - FLAX_WEIGHTS_NAME, - TF2_WEIGHTS_NAME, - WEIGHTS_NAME, - ContextManagers, - find_labels, - get_file_from_repo, - has_file, - is_flax_available, - is_tf_available, - is_torch_available, -) +from transformers.utils import ContextManagers, find_labels, is_flax_available, is_tf_available, is_torch_available MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER @@ -77,38 +63,6 @@ class TestImportMechanisms(unittest.TestCase): assert importlib.util.find_spec("transformers") is not None -class GetFromCacheTests(unittest.TestCase): - def test_has_file(self): - self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME)) - self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME)) - self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME)) - - def test_get_file_from_repo_distant(self): - # `get_file_from_repo` returns None if the file does not exist - self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt")) - - # The function raises if the repository does not exist. - with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"): - get_file_from_repo("bert-base-case", "config.json") - - # The function raises if the revision does not exist. - with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"): - get_file_from_repo("bert-base-cased", "config.json", revision="ahaha") - - resolved_file = get_file_from_repo("bert-base-cased", "config.json") - # The name is the cached name which is not very easy to test, so instead we load the content. - config = json.loads(open(resolved_file, "r").read()) - self.assertEqual(config["hidden_size"], 768) - - def test_get_file_from_repo_local(self): - with tempfile.TemporaryDirectory() as tmp_dir: - filename = Path(tmp_dir) / "a.txt" - filename.touch() - self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename)) - - self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) - - class GenericUtilTests(unittest.TestCase): @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) def test_context_managers_no_context(self, mock_stdout): diff --git a/tests/utils/test_hub_utils.py b/tests/utils/test_hub_utils.py new file mode 100644 index 0000000000..f55a0ae431 --- /dev/null +++ b/tests/utils/test_hub_utils.py @@ -0,0 +1,110 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import tempfile +import unittest +from pathlib import Path + +from transformers.utils import ( + CONFIG_NAME, + FLAX_WEIGHTS_NAME, + TF2_WEIGHTS_NAME, + TRANSFORMERS_CACHE, + WEIGHTS_NAME, + cached_file, + get_file_from_repo, + has_file, +) + + +RANDOM_BERT = "hf-internal-testing/tiny-random-bert" +CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert") +FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6" + + +class GetFromCacheTests(unittest.TestCase): + def test_cached_file(self): + archive_file = cached_file(RANDOM_BERT, CONFIG_NAME) + # Should have downloaded the file in here + self.assertTrue(os.path.isdir(CACHE_DIR)) + # Cache should contain at least those three subfolders: + for subfolder in ["blobs", "refs", "snapshots"]: + self.assertTrue(os.path.isdir(os.path.join(CACHE_DIR, subfolder))) + with open(os.path.join(CACHE_DIR, "refs", "main")) as f: + main_commit = f.read() + self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", main_commit, CONFIG_NAME)) + self.assertTrue(os.path.isfile(archive_file)) + + # File is cached at the same place the second time. + new_archive_file = cached_file(RANDOM_BERT, CONFIG_NAME) + self.assertEqual(archive_file, new_archive_file) + + # Using a specific revision to test the full commit hash. + archive_file = cached_file(RANDOM_BERT, CONFIG_NAME, revision="9b8c223") + self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", FULL_COMMIT_HASH, CONFIG_NAME)) + + def test_cached_file_errors(self): + with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"): + _ = cached_file("tiny-random-bert", CONFIG_NAME) + + with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"): + _ = cached_file(RANDOM_BERT, CONFIG_NAME, revision="aaaa") + + with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"): + _ = cached_file(RANDOM_BERT, "conf") + + def test_non_existence_is_cached(self): + with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"): + _ = cached_file(RANDOM_BERT, "conf") + + with open(os.path.join(CACHE_DIR, "refs", "main")) as f: + main_commit = f.read() + self.assertTrue(os.path.isfile(os.path.join(CACHE_DIR, ".no_exist", main_commit, "conf"))) + + path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_missing_entries=False) + self.assertIsNone(path) + + path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False) + self.assertIsNone(path) + + def test_has_file(self): + self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME)) + self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME)) + self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME)) + + def test_get_file_from_repo_distant(self): + # `get_file_from_repo` returns None if the file does not exist + self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt")) + + # The function raises if the repository does not exist. + with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"): + get_file_from_repo("bert-base-case", CONFIG_NAME) + + # The function raises if the revision does not exist. + with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"): + get_file_from_repo("bert-base-cased", CONFIG_NAME, revision="ahaha") + + resolved_file = get_file_from_repo("bert-base-cased", CONFIG_NAME) + # The name is the cached name which is not very easy to test, so instead we load the content. + config = json.loads(open(resolved_file, "r").read()) + self.assertEqual(config["hidden_size"], 768) + + def test_get_file_from_repo_local(self): + with tempfile.TemporaryDirectory() as tmp_dir: + filename = Path(tmp_dir) / "a.txt" + filename.touch() + self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename)) + + self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 167bf75db1..0af1a8ad8e 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -354,7 +354,7 @@ SPECIAL_MODULE_TO_TEST_MAP = { "feature_extraction_utils.py": "test_feature_extraction_common.py", "file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"], "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"], - "utils/hub.py": "utils/test_file_utils.py", + "utils/hub.py": "utils/test_hub_utils.py", "modelcard.py": "utils/test_model_card.py", "modeling_flax_utils.py": "test_modeling_flax_common.py", "modeling_tf_utils.py": ["test_modeling_tf_common.py", "utils/test_modeling_tf_core.py"],