From 37aeb5787a4a0bac2362397002bf5b92161a35d9 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 13 Jul 2022 12:43:08 -0400 Subject: [PATCH] Make sharded checkpoints work in offline mode (#18125) * Make sharded checkpoints work in offline mode * Add test --- src/transformers/utils/hub.py | 5 ++-- tests/utils/test_offline.py | 52 ++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 6de0b3a246..cb400329d3 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -552,8 +552,9 @@ def get_from_cache( # the models might've been found if local_files_only=False # Notify the user about that if local_files_only: - raise FileNotFoundError( - "Cannot find the requested files in the cached path and outgoing traffic has been" + fname = url.split("/")[-1] + raise EntryNotFoundError( + f"Cannot find the requested file ({fname}) in the cached path and outgoing traffic has been" " disabled. To enable model look-ups and downloads online, set 'local_files_only'" " to False." ) diff --git a/tests/utils/test_offline.py b/tests/utils/test_offline.py index 33f5d4bd0a..0636a4399e 100644 --- a/tests/utils/test_offline.py +++ b/tests/utils/test_offline.py @@ -34,7 +34,7 @@ from transformers import BertConfig, BertModel, BertTokenizer """ run = """ -mname = "lysandre/tiny-bert-random" +mname = "hf-internal-testing/tiny-random-bert" BertConfig.from_pretrained(mname) BertModel.from_pretrained(mname) BertTokenizer.from_pretrained(mname) @@ -69,3 +69,53 @@ socket.socket = offline_socket result = subprocess.run(cmd, env=env, check=False, capture_output=True) self.assertEqual(result.returncode, 0, result.stderr) self.assertIn("success", result.stdout.decode()) + + @require_torch + def test_offline_mode_sharded_checkpoint(self): + + # this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before + # `transformers` is loaded, and it's too late for inside pytest - so we are changing it + # while running an external program + + # python one-liner segments + + # this must be loaded before socket.socket is monkey-patched + load = """ +from transformers import BertConfig, BertModel, BertTokenizer + """ + + run = """ +mname = "hf-internal-testing/tiny-random-bert-sharded" +BertConfig.from_pretrained(mname) +BertModel.from_pretrained(mname) +print("success") + """ + + mock = """ +import socket +def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") +socket.socket = offline_socket + """ + + # baseline - just load from_pretrained with normal network + cmd = [sys.executable, "-c", "\n".join([load, run])] + + # should succeed + env = self.get_env() + result = subprocess.run(cmd, env=env, check=False, capture_output=True) + self.assertEqual(result.returncode, 0, result.stderr) + self.assertIn("success", result.stdout.decode()) + + # next emulate no network + cmd = [sys.executable, "-c", "\n".join([load, mock, run])] + + # Doesn't fail anymore since the model is in the cache due to other tests, so commenting this. + # env["TRANSFORMERS_OFFLINE"] = "0" + # result = subprocess.run(cmd, env=env, check=False, capture_output=True) + # self.assertEqual(result.returncode, 1, result.stderr) + + # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files + env["TRANSFORMERS_OFFLINE"] = "1" + result = subprocess.run(cmd, env=env, check=False, capture_output=True) + self.assertEqual(result.returncode, 0, result.stderr) + self.assertIn("success", result.stdout.decode())