Fixing offline mode for pipeline (when inferring task). (#21113)

* Fixing offline mode for pipeline (when inferring task).

* Update src/transformers/pipelines/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Updating test to reflect change in exception.

* Fixing offline mode.

* Clean.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-01-17 15:24:40 +01:00
committed by GitHub
parent 8896ebb9a9
commit 25ddd91b24
2 changed files with 75 additions and 12 deletions

View File

@@ -40,6 +40,7 @@ from ..tokenization_utils_fast import PreTrainedTokenizerFast
from ..utils import ( from ..utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
is_kenlm_available, is_kenlm_available,
is_offline_mode,
is_pyctcdecode_available, is_pyctcdecode_available,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
@@ -398,6 +399,8 @@ def get_supported_tasks() -> List[str]:
def get_task(model: str, use_auth_token: Optional[str] = None) -> str: def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
if is_offline_mode():
raise RuntimeError(f"You cannot infer task automatically within `pipeline` when using offline mode")
try: try:
info = model_info(model, token=use_auth_token) info = model_info(model, token=use_auth_token)
except Exception as e: except Exception as e:

View File

@@ -15,6 +15,7 @@
import subprocess import subprocess
import sys import sys
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
from transformers.testing_utils import TestCasePlus, require_torch from transformers.testing_utils import TestCasePlus, require_torch
@@ -30,7 +31,7 @@ class OfflineTests(TestCasePlus):
# this must be loaded before socket.socket is monkey-patched # this must be loaded before socket.socket is monkey-patched
load = """ load = """
from transformers import BertConfig, BertModel, BertTokenizer from transformers import BertConfig, BertModel, BertTokenizer, pipeline
""" """
run = """ run = """
@@ -38,34 +39,69 @@ mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname) BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname) BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname) BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success") print("success")
""" """
mock = """ mock = """
import socket import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
socket.socket = offline_socket socket.socket = offline_socket
""" """
# Force fetching the files so that we can use the cache
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipeline(task="fill-mask", model=mname)
# baseline - just load from_pretrained with normal network # baseline - just load from_pretrained with normal network
cmd = [sys.executable, "-c", "\n".join([load, run])] cmd = [sys.executable, "-c", "\n".join([load, run, mock])]
# should succeed # should succeed
env = self.get_env() env = self.get_env()
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
env["TRANSFORMERS_OFFLINE"] = "1"
result = subprocess.run(cmd, env=env, check=False, capture_output=True) result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr) self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode()) self.assertIn("success", result.stdout.decode())
# next emulate no network @require_torch
cmd = [sys.executable, "-c", "\n".join([load, mock, run])] def test_offline_mode_no_internet(self):
# python one-liner segments
# this must be loaded before socket.socket is monkey-patched
load = """
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
"""
# Doesn't fail anymore since the model is in the cache due to other tests, so commenting this. run = """
# env["TRANSFORMERS_OFFLINE"] = "0" mname = "hf-internal-testing/tiny-random-bert"
# result = subprocess.run(cmd, env=env, check=False, capture_output=True) BertConfig.from_pretrained(mname)
# self.assertEqual(result.returncode, 1, result.stderr) BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
"""
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files mock = """
env["TRANSFORMERS_OFFLINE"] = "1" import socket
def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
socket.socket = offline_socket
"""
# Force fetching the files so that we can use the cache
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipeline(task="fill-mask", model=mname)
# baseline - just load from_pretrained with normal network
cmd = [sys.executable, "-c", "\n".join([load, run, mock])]
# should succeed
env = self.get_env()
result = subprocess.run(cmd, env=env, check=False, capture_output=True) result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr) self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode()) self.assertIn("success", result.stdout.decode())
@@ -93,7 +129,7 @@ print("success")
mock = """ mock = """
import socket import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
socket.socket = offline_socket socket.socket = offline_socket
""" """
@@ -119,3 +155,27 @@ socket.socket = offline_socket
result = subprocess.run(cmd, env=env, check=False, capture_output=True) result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr) self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode()) self.assertIn("success", result.stdout.decode())
@require_torch
def test_offline_mode_pipeline_exception(self):
load = """
from transformers import pipeline
"""
run = """
mname = "hf-internal-testing/tiny-random-bert"
pipe = pipeline(model=mname)
"""
mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
socket.socket = offline_socket
"""
env = self.get_env()
env["TRANSFORMERS_OFFLINE"] = "1"
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 1, result.stderr)
self.assertIn(
"You cannot infer task automatically within `pipeline` when using offline mode", result.stderr.decode()
)