diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 046980ad13..62a124f7d3 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -22,9 +22,14 @@ import sys from pathlib import Path from typing import Dict, Optional, Union -from huggingface_hub import model_info - -from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging +from .utils import ( + HF_MODULES_CACHE, + TRANSFORMERS_DYNAMIC_MODULE_NAME, + cached_file, + extract_commit_hash, + is_offline_mode, + logging, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -163,6 +168,7 @@ def get_cached_module_file( use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, + _commit_hash: Optional[str] = None, ): """ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached @@ -233,6 +239,7 @@ def get_cached_module_file( local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, + _commit_hash=_commit_hash, ) except EnvironmentError: @@ -264,8 +271,7 @@ def get_cached_module_file( importlib.invalidate_caches() else: # Get the commit hash - # TODO: we will get this info in the etag soon, so retrieve it from there and not here. - commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=use_auth_token).sha + commit_hash = extract_commit_hash(resolved_module_file, _commit_hash) # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the # benefit of versioning. 
diff --git a/tests/utils/test_offline.py b/tests/utils/test_offline.py
index 1fda8804a8..ad59747386 100644
--- a/tests/utils/test_offline.py
+++ b/tests/utils/test_offline.py
@@ -177,3 +177,29 @@ socket.socket = offline_socket
         self.assertIn(
             "You cannot infer task automatically within `pipeline` when using offline mode", result.stderr.decode()
         )
+
+    @require_torch
+    def test_offline_model_dynamic_model(self):
+        load = """
+from transformers import AutoModel
+        """
+        run = """
+mname = "hf-internal-testing/test_dynamic_model"
+AutoModel.from_pretrained(mname, trust_remote_code=True)
+print("success")
+        """
+
+        # one command, reused for both runs below: load the trust_remote_code model in a fresh interpreter
+        cmd = [sys.executable, "-c", "\n".join([load, run])]
+
+        # baseline run with the unmodified env from get_env() (normal network expected): should succeed
+        env = self.get_env()
+        result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+        self.assertEqual(result.returncode, 0, result.stderr)
+        self.assertIn("success", result.stdout.decode())
+
+        # rerun with TRANSFORMERS_OFFLINE=1: should still succeed, since that tells it to use local files
+        env["TRANSFORMERS_OFFLINE"] = "1"
+        result = subprocess.run(cmd, env=env, check=False, capture_output=True)
+        self.assertEqual(result.returncode, 0, result.stderr)
+        self.assertIn("success", result.stdout.decode())