offline mode for firewalled envs (part 2) (#10569)
* more readable test * add all the missing places * one more nltk * better exception check * revert
This commit is contained in:
@@ -35,6 +35,7 @@ from .file_utils import (
|
|||||||
cached_path,
|
cached_path,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -342,6 +343,10 @@ class PreTrainedFeatureExtractor:
|
|||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
|
|
||||||
|
if is_offline_mode() and not local_files_only:
|
||||||
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
|
local_files_only = True
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||||
|
|||||||
@@ -1105,6 +1105,10 @@ def cached_path(
|
|||||||
if isinstance(cache_dir, Path):
|
if isinstance(cache_dir, Path):
|
||||||
cache_dir = str(cache_dir)
|
cache_dir = str(cache_dir)
|
||||||
|
|
||||||
|
if is_offline_mode() and not local_files_only:
|
||||||
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
|
local_files_only = True
|
||||||
|
|
||||||
if is_remote_url(url_or_filename):
|
if is_remote_url(url_or_filename):
|
||||||
# URL, so get it from the cache (downloading if necessary)
|
# URL, so get it from the cache (downloading if necessary)
|
||||||
output_path = get_from_cache(
|
output_path = get_from_cache(
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict
|
|||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
|
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -229,6 +229,10 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
|
|
||||||
|
if is_offline_mode() and not local_files_only:
|
||||||
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
|
local_files_only = True
|
||||||
|
|
||||||
# Load config if we don't provide a configuration
|
# Load config if we don't provide a configuration
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config_path = config if config is not None else pretrained_model_name_or_path
|
config_path = config if config is not None else pretrained_model_name_or_path
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from .file_utils import (
|
|||||||
ModelOutput,
|
ModelOutput,
|
||||||
cached_path,
|
cached_path,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
)
|
)
|
||||||
from .generation_tf_utils import TFGenerationMixin
|
from .generation_tf_utils import TFGenerationMixin
|
||||||
@@ -1151,6 +1152,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
mirror = kwargs.pop("mirror", None)
|
mirror = kwargs.pop("mirror", None)
|
||||||
|
|
||||||
|
if is_offline_mode() and not local_files_only:
|
||||||
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
|
local_files_only = True
|
||||||
|
|
||||||
# Load config if we don't provide a configuration
|
# Load config if we don't provide a configuration
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config_path = config if config is not None else pretrained_model_name_or_path
|
config_path = config if config is not None else pretrained_model_name_or_path
|
||||||
|
|||||||
@@ -27,20 +27,37 @@ class OfflineTests(TestCasePlus):
|
|||||||
# while running an external program
|
# while running an external program
|
||||||
|
|
||||||
# python one-liner segments
|
# python one-liner segments
|
||||||
load = "from transformers import BertConfig, BertModel, BertTokenizer;"
|
|
||||||
run = "mname = 'lysandre/tiny-bert-random'; BertConfig.from_pretrained(mname) and BertModel.from_pretrained(mname) and BertTokenizer.from_pretrained(mname);"
|
# this must be loaded before socket.socket is monkey-patched
|
||||||
mock = 'import socket; exec("def offline_socket(*args, **kwargs): raise socket.error(\\"Offline mode is enabled.\\")"); socket.socket = offline_socket;'
|
load = """
|
||||||
|
from transformers import BertConfig, BertModel, BertTokenizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
run = """
|
||||||
|
mname = "lysandre/tiny-bert-random"
|
||||||
|
BertConfig.from_pretrained(mname)
|
||||||
|
BertModel.from_pretrained(mname)
|
||||||
|
BertTokenizer.from_pretrained(mname)
|
||||||
|
print("success")
|
||||||
|
"""
|
||||||
|
|
||||||
|
mock = """
|
||||||
|
import socket
|
||||||
|
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
|
||||||
|
socket.socket = offline_socket
|
||||||
|
"""
|
||||||
|
|
||||||
# baseline - just load from_pretrained with normal network
|
# baseline - just load from_pretrained with normal network
|
||||||
cmd = [sys.executable, "-c", f"{load} {run}"]
|
cmd = [sys.executable, "-c", "\n".join([load, run])]
|
||||||
|
|
||||||
# should succeed
|
# should succeed
|
||||||
env = self.get_env()
|
env = self.get_env()
|
||||||
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
||||||
self.assertEqual(result.returncode, 0, result.stderr)
|
self.assertEqual(result.returncode, 0, result.stderr)
|
||||||
|
self.assertIn("success", result.stdout.decode())
|
||||||
|
|
||||||
# next emulate no network
|
# next emulate no network
|
||||||
cmd = [sys.executable, "-c", f"{load} {mock} {run}"]
|
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
|
||||||
|
|
||||||
# should normally fail as it will fail to lookup the model files w/o the network
|
# should normally fail as it will fail to lookup the model files w/o the network
|
||||||
env["TRANSFORMERS_OFFLINE"] = "0"
|
env["TRANSFORMERS_OFFLINE"] = "0"
|
||||||
@@ -51,3 +68,4 @@ class OfflineTests(TestCasePlus):
|
|||||||
env["TRANSFORMERS_OFFLINE"] = "1"
|
env["TRANSFORMERS_OFFLINE"] = "1"
|
||||||
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
|
||||||
self.assertEqual(result.returncode, 0, result.stderr)
|
self.assertEqual(result.returncode, 0, result.stderr)
|
||||||
|
self.assertIn("success", result.stdout.decode())
|
||||||
|
|||||||
Reference in New Issue
Block a user