Fixing offline mode for pipeline (when inferring task). (#21113)
* Fixing offline mode for pipeline (when inferring task).
* Update src/transformers/pipelines/__init__.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Updating test to reflect change in exception.
* Fixing offline mode.
* Clean.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,7 @@ from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
|||||||
from ..utils import (
|
from ..utils import (
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
is_kenlm_available,
|
is_kenlm_available,
|
||||||
|
is_offline_mode,
|
||||||
is_pyctcdecode_available,
|
is_pyctcdecode_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -398,6 +399,8 @@ def get_supported_tasks() -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
||||||
|
if is_offline_mode():
|
||||||
|
raise RuntimeError(f"You cannot infer task automatically within `pipeline` when using offline mode")
|
||||||
try:
|
try:
|
||||||
info = model_info(model, token=use_auth_token)
|
info = model_info(model, token=use_auth_token)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
|
||||||
from transformers.testing_utils import TestCasePlus, require_torch
|
from transformers.testing_utils import TestCasePlus, require_torch
|
||||||
|
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ class OfflineTests(TestCasePlus):
|
|||||||
|
|
||||||
# this must be loaded before socket.socket is monkey-patched
|
# this must be loaded before socket.socket is monkey-patched
|
||||||
load = """
|
load = """
|
||||||
from transformers import BertConfig, BertModel, BertTokenizer
|
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
|
||||||
"""
|
"""
|
||||||
|
|
||||||
run = """
|
run = """
|
||||||
@@ -38,34 +39,69 @@ mname = "hf-internal-testing/tiny-random-bert"
|
|||||||
BertConfig.from_pretrained(mname)
|
BertConfig.from_pretrained(mname)
|
||||||
BertModel.from_pretrained(mname)
|
BertModel.from_pretrained(mname)
|
||||||
BertTokenizer.from_pretrained(mname)
|
BertTokenizer.from_pretrained(mname)
|
||||||
|
pipe = pipeline(task="fill-mask", model=mname)
|
||||||
print("success")
|
print("success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mock = """
|
mock = """
|
||||||
import socket
|
import socket
|
||||||
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
|
def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
|
||||||
socket.socket = offline_socket
|
socket.socket = offline_socket
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Force fetching the files so that we can use the cache
|
||||||
|
mname = "hf-internal-testing/tiny-random-bert"
|
||||||
|
BertConfig.from_pretrained(mname)
|
||||||
|
BertModel.from_pretrained(mname)
|
||||||
|
BertTokenizer.from_pretrained(mname)
|
||||||
|
pipeline(task="fill-mask", model=mname)
|
||||||
|
|
||||||
# baseline - just load from_pretrained with normal network
|
# baseline - just load from_pretrained with normal network
|
||||||
cmd = [sys.executable, "-c", "\n".join([load, run])]
|
cmd = [sys.executable, "-c", "\n".join([load, run, mock])]
|
||||||
|
|
||||||
# should succeed
|
# should succeed
|
||||||
env = self.get_env()
|
env = self.get_env()
|
||||||
|
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
|
||||||
|
env["TRANSFORMERS_OFFLINE"] = "1"
|
||||||
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
||||||
self.assertEqual(result.returncode, 0, result.stderr)
|
self.assertEqual(result.returncode, 0, result.stderr)
|
||||||
self.assertIn("success", result.stdout.decode())
|
self.assertIn("success", result.stdout.decode())
|
||||||
|
|
||||||
# next emulate no network
|
@require_torch
|
||||||
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
|
def test_offline_mode_no_internet(self):
|
||||||
|
# python one-liner segments
|
||||||
|
# this must be loaded before socket.socket is monkey-patched
|
||||||
|
load = """
|
||||||
|
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
|
||||||
|
"""
|
||||||
|
|
||||||
# Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
|
run = """
|
||||||
# env["TRANSFORMERS_OFFLINE"] = "0"
|
mname = "hf-internal-testing/tiny-random-bert"
|
||||||
# result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
BertConfig.from_pretrained(mname)
|
||||||
# self.assertEqual(result.returncode, 1, result.stderr)
|
BertModel.from_pretrained(mname)
|
||||||
|
BertTokenizer.from_pretrained(mname)
|
||||||
|
pipe = pipeline(task="fill-mask", model=mname)
|
||||||
|
print("success")
|
||||||
|
"""
|
||||||
|
|
||||||
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
|
mock = """
|
||||||
env["TRANSFORMERS_OFFLINE"] = "1"
|
import socket
|
||||||
|
def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
|
||||||
|
socket.socket = offline_socket
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Force fetching the files so that we can use the cache
|
||||||
|
mname = "hf-internal-testing/tiny-random-bert"
|
||||||
|
BertConfig.from_pretrained(mname)
|
||||||
|
BertModel.from_pretrained(mname)
|
||||||
|
BertTokenizer.from_pretrained(mname)
|
||||||
|
pipeline(task="fill-mask", model=mname)
|
||||||
|
|
||||||
|
# baseline - just load from_pretrained with normal network
|
||||||
|
cmd = [sys.executable, "-c", "\n".join([load, run, mock])]
|
||||||
|
|
||||||
|
# should succeed
|
||||||
|
env = self.get_env()
|
||||||
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
||||||
self.assertEqual(result.returncode, 0, result.stderr)
|
self.assertEqual(result.returncode, 0, result.stderr)
|
||||||
self.assertIn("success", result.stdout.decode())
|
self.assertIn("success", result.stdout.decode())
|
||||||
@@ -93,7 +129,7 @@ print("success")
|
|||||||
|
|
||||||
mock = """
|
mock = """
|
||||||
import socket
|
import socket
|
||||||
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
|
def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
|
||||||
socket.socket = offline_socket
|
socket.socket = offline_socket
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -119,3 +155,27 @@ socket.socket = offline_socket
|
|||||||
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
||||||
self.assertEqual(result.returncode, 0, result.stderr)
|
self.assertEqual(result.returncode, 0, result.stderr)
|
||||||
self.assertIn("success", result.stdout.decode())
|
self.assertIn("success", result.stdout.decode())
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_offline_mode_pipeline_exception(self):
|
||||||
|
load = """
|
||||||
|
from transformers import pipeline
|
||||||
|
"""
|
||||||
|
run = """
|
||||||
|
mname = "hf-internal-testing/tiny-random-bert"
|
||||||
|
pipe = pipeline(model=mname)
|
||||||
|
"""
|
||||||
|
|
||||||
|
mock = """
|
||||||
|
import socket
|
||||||
|
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
|
||||||
|
socket.socket = offline_socket
|
||||||
|
"""
|
||||||
|
env = self.get_env()
|
||||||
|
env["TRANSFORMERS_OFFLINE"] = "1"
|
||||||
|
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
|
||||||
|
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
||||||
|
self.assertEqual(result.returncode, 1, result.stderr)
|
||||||
|
self.assertIn(
|
||||||
|
"You cannot infer task automatically within `pipeline` when using offline mode", result.stderr.decode()
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user