offline mode for firewalled envs (part 2) (#10569)

* more readable test

* add all the missing places

* one more nltk

* better exception check

* revert
This commit is contained in:
Stas Bekman
2021-03-08 08:52:20 -08:00
committed by GitHub
parent 5469369480
commit 6f84531e61
5 changed files with 42 additions and 6 deletions

View File

@@ -27,20 +27,37 @@ class OfflineTests(TestCasePlus):
# while running an external program
# python one-liner segments
load = "from transformers import BertConfig, BertModel, BertTokenizer;"
run = "mname = 'lysandre/tiny-bert-random'; BertConfig.from_pretrained(mname) and BertModel.from_pretrained(mname) and BertTokenizer.from_pretrained(mname);"
mock = 'import socket; exec("def offline_socket(*args, **kwargs): raise socket.error(\\"Offline mode is enabled.\\")"); socket.socket = offline_socket;'
# this must be loaded before socket.socket is monkey-patched
load = """
from transformers import BertConfig, BertModel, BertTokenizer
"""
run = """
mname = "lysandre/tiny-bert-random"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
print("success")
"""
mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
socket.socket = offline_socket
"""
# baseline - just load from_pretrained with normal network
cmd = [sys.executable, "-c", f"{load} {run}"]
cmd = [sys.executable, "-c", "\n".join([load, run])]
# should succeed
env = self.get_env()
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())
# next emulate no network
cmd = [sys.executable, "-c", f"{load} {mock} {run}"]
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
# should normally fail as it will fail to lookup the model files w/o the network
env["TRANSFORMERS_OFFLINE"] = "0"
@@ -51,3 +68,4 @@ class OfflineTests(TestCasePlus):
env["TRANSFORMERS_OFFLINE"] = "1"
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 0, result.stderr)
self.assertIn("success", result.stdout.decode())