From 25ddd91b249014d818fb2ed3d4ba856ed9a5653e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 17 Jan 2023 15:24:40 +0100 Subject: [PATCH] Fixing offline mode for pipeline (when inferring task). (#21113) * Fixing offline mode for pipeline (when inferring task). * Update src/transformers/pipelines/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Updating test to reflect change in exception. * Fixing offline mode. * Clean. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/pipelines/__init__.py | 3 + tests/utils/test_offline.py | 84 ++++++++++++++++++++++---- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 8b06009a4c..94eb67b90a 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -40,6 +40,7 @@ from ..tokenization_utils_fast import PreTrainedTokenizerFast from ..utils import ( HUGGINGFACE_CO_RESOLVE_ENDPOINT, is_kenlm_available, + is_offline_mode, is_pyctcdecode_available, is_tf_available, is_torch_available, @@ -398,6 +399,8 @@ def get_supported_tasks() -> List[str]: def get_task(model: str, use_auth_token: Optional[str] = None) -> str: + if is_offline_mode(): + raise RuntimeError(f"You cannot infer task automatically within `pipeline` when using offline mode") try: info = model_info(model, token=use_auth_token) except Exception as e: diff --git a/tests/utils/test_offline.py b/tests/utils/test_offline.py index 0636a4399e..708accc7a6 100644 --- a/tests/utils/test_offline.py +++ b/tests/utils/test_offline.py @@ -15,6 +15,7 @@ import subprocess import sys +from transformers import BertConfig, BertModel, BertTokenizer, pipeline from transformers.testing_utils import TestCasePlus, require_torch @@ -30,7 +31,7 @@ class OfflineTests(TestCasePlus): # this must be loaded before socket.socket is monkey-patched load = """ -from transformers import BertConfig, BertModel, BertTokenizer +from transformers import BertConfig, BertModel, BertTokenizer, pipeline """ run = """ @@ -38,34 +39,69 @@ mname = "hf-internal-testing/tiny-random-bert" BertConfig.from_pretrained(mname) BertModel.from_pretrained(mname) BertTokenizer.from_pretrained(mname) +pipe = pipeline(task="fill-mask", model=mname) print("success") """ mock = """ import socket -def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") +def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet") socket.socket = offline_socket """ + # Force fetching the files so that we can use the cache + mname = "hf-internal-testing/tiny-random-bert" + BertConfig.from_pretrained(mname) + BertModel.from_pretrained(mname) + BertTokenizer.from_pretrained(mname) + pipeline(task="fill-mask", model=mname) + # baseline - just load from_pretrained with normal network - cmd = [sys.executable, "-c", "\n".join([load, run])] + cmd = [sys.executable, "-c", "\n".join([load, run, mock])] # should succeed env = self.get_env() + # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files + env["TRANSFORMERS_OFFLINE"] = "1" result = subprocess.run(cmd, env=env, check=False, capture_output=True) self.assertEqual(result.returncode, 0, result.stderr) self.assertIn("success", result.stdout.decode()) - # next emulate no network - cmd = [sys.executable, "-c", "\n".join([load, mock, run])] + @require_torch + def test_offline_mode_no_internet(self): + # python one-liner segments + # this must be loaded before socket.socket is monkey-patched + load = """ +from transformers import BertConfig, BertModel, BertTokenizer, pipeline + """ - # Doesn't fail anymore since the model is in the cache due to other tests, so commenting this. - # env["TRANSFORMERS_OFFLINE"] = "0" - # result = subprocess.run(cmd, env=env, check=False, capture_output=True) - # self.assertEqual(result.returncode, 1, result.stderr) + run = """ +mname = "hf-internal-testing/tiny-random-bert" +BertConfig.from_pretrained(mname) +BertModel.from_pretrained(mname) +BertTokenizer.from_pretrained(mname) +pipe = pipeline(task="fill-mask", model=mname) +print("success") + """ - # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files - env["TRANSFORMERS_OFFLINE"] = "1" + mock = """ +import socket +def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet") +socket.socket = offline_socket + """ + + # Force fetching the files so that we can use the cache + mname = "hf-internal-testing/tiny-random-bert" + BertConfig.from_pretrained(mname) + BertModel.from_pretrained(mname) + BertTokenizer.from_pretrained(mname) + pipeline(task="fill-mask", model=mname) + + # baseline - just load from_pretrained with normal network + cmd = [sys.executable, "-c", "\n".join([load, run, mock])] + + # should succeed + env = self.get_env() result = subprocess.run(cmd, env=env, check=False, capture_output=True) self.assertEqual(result.returncode, 0, result.stderr) self.assertIn("success", result.stdout.decode()) @@ -93,7 +129,7 @@ print("success") mock = """ import socket -def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") +def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled") socket.socket = offline_socket """ @@ -119,3 +155,27 @@ socket.socket = offline_socket result = subprocess.run(cmd, env=env, check=False, capture_output=True) self.assertEqual(result.returncode, 0, result.stderr) self.assertIn("success", result.stdout.decode()) + + @require_torch + def test_offline_mode_pipeline_exception(self): + load = """ +from transformers import pipeline + """ + run = """ +mname = "hf-internal-testing/tiny-random-bert" +pipe = pipeline(model=mname) + """ + + mock = """ +import socket +def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") +socket.socket = offline_socket + """ + env = self.get_env() + env["TRANSFORMERS_OFFLINE"] = "1" + cmd = [sys.executable, "-c", "\n".join([load, mock, run])] + result = subprocess.run(cmd, env=env, check=False, capture_output=True) + self.assertEqual(result.returncode, 1, result.stderr) + self.assertIn( + "You cannot infer task automatically within `pipeline` when using offline mode", result.stderr.decode() + )