From 6f84531e6175700c821da3523c720d2698455e7f Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 8 Mar 2021 08:52:20 -0800 Subject: [PATCH] offline mode for firewalled envs (part 2) (#10569) * more readable test * add all the missing places * one more nltk * better exception check * revert --- src/transformers/feature_extraction_utils.py | 5 ++++ src/transformers/file_utils.py | 4 +++ src/transformers/modeling_flax_utils.py | 6 ++++- src/transformers/modeling_tf_utils.py | 5 ++++ tests/test_offline.py | 28 ++++++++++++++++---- 5 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 250a144313..3e07c4bcc8 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -35,6 +35,7 @@ from .file_utils import ( cached_path, hf_bucket_url, is_flax_available, + is_offline_mode, is_remote_url, is_tf_available, is_torch_available, @@ -342,6 +343,10 @@ class PreTrainedFeatureExtractor: local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + pretrained_model_name_or_path = str(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index e6309caaa7..c4183fa8f0 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -1105,6 +1105,10 @@ def cached_path( if isinstance(cache_dir, Path): cache_dir = str(cache_dir) + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + if is_remote_url(url_or_filename): # URL, so get it from the cache (downloading if necessary) output_path = get_from_cache( diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 4a3b5a95b3..8b245f6546 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -28,7 +28,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey from .configuration_utils import PretrainedConfig -from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url +from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url from .utils import logging @@ -229,6 +229,10 @@ class FlaxPreTrainedModel(ABC): use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + # Load config if we don't provide a configuration if not isinstance(config, PretrainedConfig): config_path = config if config is not None else pretrained_model_name_or_path diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 38160fa542..720a052593 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -36,6 +36,7 @@ from .file_utils import ( ModelOutput, cached_path, hf_bucket_url, + is_offline_mode, is_remote_url, ) from .generation_tf_utils import TFGenerationMixin @@ -1151,6 +1152,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): revision = kwargs.pop("revision", None) mirror = kwargs.pop("mirror", None) + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + # Load config if we don't provide a configuration if not isinstance(config, PretrainedConfig): config_path = config if config is not None else pretrained_model_name_or_path diff --git a/tests/test_offline.py b/tests/test_offline.py index 5217c5d6af..45a12a1f2b 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -27,20 +27,37 @@ class OfflineTests(TestCasePlus): # while running an external program # python one-liner segments - load = "from transformers import BertConfig, BertModel, BertTokenizer;" - run = "mname = 'lysandre/tiny-bert-random'; BertConfig.from_pretrained(mname) and BertModel.from_pretrained(mname) and BertTokenizer.from_pretrained(mname);" - mock = 'import socket; exec("def offline_socket(*args, **kwargs): raise socket.error(\\"Offline mode is enabled.\\")"); socket.socket = offline_socket;' + + # this must be loaded before socket.socket is monkey-patched + load = """ +from transformers import BertConfig, BertModel, BertTokenizer + """ + + run = """ +mname = "lysandre/tiny-bert-random" +BertConfig.from_pretrained(mname) +BertModel.from_pretrained(mname) +BertTokenizer.from_pretrained(mname) +print("success") + """ + + mock = """ +import socket +def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled") +socket.socket = offline_socket + """ # baseline - just load from_pretrained with normal network - cmd = [sys.executable, "-c", f"{load} {run}"] + cmd = [sys.executable, "-c", "\n".join([load, run])] # should succeed env = self.get_env() result = subprocess.run(cmd, env=env, check=False, capture_output=True) self.assertEqual(result.returncode, 0, result.stderr) + self.assertIn("success", result.stdout.decode()) # next emulate no network - cmd = [sys.executable, "-c", f"{load} {mock} {run}"] + cmd = [sys.executable, "-c", "\n".join([load, mock, run])] # should normally fail as it will fail to lookup the model files w/o the network env["TRANSFORMERS_OFFLINE"] = "0" @@ -51,3 +68,4 @@ class OfflineTests(TestCasePlus): env["TRANSFORMERS_OFFLINE"] = "1" result = subprocess.run(cmd, env=env, check=False, capture_output=True) self.assertEqual(result.returncode, 0, result.stderr) + self.assertIn("success", result.stdout.decode())