Load dynamic module (remote code) only once if code isn't change (#33162)
* Load remote code only once * Use hash as load indicator * Add a new option `force_reload` for old behavior (i.e. always reload) * Add test for dynamic module is cached * Add more type annotations to improve code readability * Address comments from code review
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
"""Utilities to dynamically load objects from the Hub."""
|
||||
|
||||
import filecmp
|
||||
import hashlib
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
@@ -22,9 +23,11 @@ import re
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
@@ -40,6 +43,7 @@ from .utils import (
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def init_hf_modules():
|
||||
@@ -58,7 +62,7 @@ def init_hf_modules():
|
||||
importlib.invalidate_caches()
|
||||
|
||||
|
||||
def create_dynamic_module(name: Union[str, os.PathLike]):
|
||||
def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
|
||||
"""
|
||||
Creates a dynamic module in the cache directory for modules.
|
||||
|
||||
@@ -191,13 +195,21 @@ def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
|
||||
return get_relative_imports(filename)
|
||||
|
||||
|
||||
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
|
||||
def get_class_in_module(
|
||||
class_name: str,
|
||||
module_path: Union[str, os.PathLike],
|
||||
*,
|
||||
force_reload: bool = False,
|
||||
) -> typing.Type:
|
||||
"""
|
||||
Import a module on the cache directory for modules and extract a class from it.
|
||||
|
||||
Args:
|
||||
class_name (`str`): The name of the class to import.
|
||||
module_path (`str` or `os.PathLike`): The path to the module to import.
|
||||
force_reload (`bool`, *optional*, defaults to `False`):
|
||||
Whether to reload the dynamic module from file if it already exists in `sys.modules`.
|
||||
Otherwise, the module is only reloaded if the file has changed.
|
||||
|
||||
Returns:
|
||||
`typing.Type`: The class looked for.
|
||||
@@ -206,15 +218,30 @@ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -
|
||||
if name.endswith(".py"):
|
||||
name = name[:-3]
|
||||
name = name.replace(os.path.sep, ".")
|
||||
module_spec = importlib.util.spec_from_file_location(name, location=Path(HF_MODULES_CACHE) / module_path)
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
# insert it into sys.modules before any loading begins
|
||||
sys.modules[name] = module
|
||||
# reload in both cases
|
||||
module_spec.loader.exec_module(module)
|
||||
return getattr(module, class_name)
|
||||
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||
with _HF_REMOTE_CODE_LOCK:
|
||||
if force_reload:
|
||||
sys.modules.pop(name, None)
|
||||
importlib.invalidate_caches()
|
||||
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||
|
||||
# Hash the module file and all its relative imports to check if we need to reload it
|
||||
module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
|
||||
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
|
||||
|
||||
module: ModuleType
|
||||
if cached_module is None:
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
# insert it into sys.modules before any loading begins
|
||||
sys.modules[name] = module
|
||||
else:
|
||||
module = cached_module
|
||||
# reload in both cases, unless the module is already imported and the hash hits
|
||||
if getattr(module, "__transformers_module_hash__", "") != module_hash:
|
||||
module_spec.loader.exec_module(module)
|
||||
module.__transformers_module_hash__ = module_hash
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def get_cached_module_file(
|
||||
@@ -515,7 +542,7 @@ def get_class_from_dynamic_module(
|
||||
local_files_only=local_files_only,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
return get_class_in_module(class_name, final_module)
|
||||
return get_class_in_module(class_name, final_module, force_reload=force_download)
|
||||
|
||||
|
||||
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
|
||||
|
||||
@@ -122,12 +122,27 @@ class AutoConfigTest(unittest.TestCase):
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
|
||||
self.assertEqual(config.__class__.__name__, "NewModelConfig")
|
||||
|
||||
# Test the dynamic module is loaded only once.
|
||||
reloaded_config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
|
||||
self.assertIs(config.__class__, reloaded_config.__class__)
|
||||
|
||||
# Test config can be reloaded.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir)
|
||||
reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True)
|
||||
self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig")
|
||||
|
||||
# The configuration file is cached in the snapshot directory. So the module file is not changed after dumping
|
||||
# to a temp dir. Because the revision of the configuration file is not changed.
|
||||
# Test the dynamic module is loaded only once if the configuration file is not changed.
|
||||
self.assertIs(config.__class__, reloaded_config.__class__)
|
||||
|
||||
# Test the dynamic module is reloaded if we force it.
|
||||
reloaded_config = AutoConfig.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True
|
||||
)
|
||||
self.assertIsNot(config.__class__, reloaded_config.__class__)
|
||||
|
||||
def test_from_pretrained_dynamic_config_conflict(self):
|
||||
class NewModelConfigLocal(BertConfig):
|
||||
model_type = "new-model"
|
||||
|
||||
@@ -116,12 +116,29 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
|
||||
|
||||
# Test the dynamic module is loaded only once.
|
||||
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
|
||||
)
|
||||
self.assertIs(feature_extractor.__class__, reloaded_feature_extractor.__class__)
|
||||
|
||||
# Test feature extractor can be reloaded.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(tmp_dir)
|
||||
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir, trust_remote_code=True)
|
||||
self.assertEqual(reloaded_feature_extractor.__class__.__name__, "NewFeatureExtractor")
|
||||
|
||||
# The feature extractor file is cached in the snapshot directory. So the module file is not changed after dumping
|
||||
# to a temp dir. Because the revision of the module file is not changed.
|
||||
# Test the dynamic module is loaded only once if the module file is not changed.
|
||||
self.assertIs(feature_extractor.__class__, reloaded_feature_extractor.__class__)
|
||||
|
||||
# Test the dynamic module is reloaded if we force it.
|
||||
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True, force_download=True
|
||||
)
|
||||
self.assertIsNot(feature_extractor.__class__, reloaded_feature_extractor.__class__)
|
||||
|
||||
def test_new_feature_extractor_registration(self):
|
||||
try:
|
||||
AutoConfig.register("custom", CustomConfig)
|
||||
|
||||
@@ -167,12 +167,29 @@ class AutoImageProcessorTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor")
|
||||
|
||||
# Test the dynamic module is loaded only once.
|
||||
reloaded_image_processor = AutoImageProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True
|
||||
)
|
||||
self.assertIs(image_processor.__class__, reloaded_image_processor.__class__)
|
||||
|
||||
# Test image processor can be reloaded.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
image_processor.save_pretrained(tmp_dir)
|
||||
reloaded_image_processor = AutoImageProcessor.from_pretrained(tmp_dir, trust_remote_code=True)
|
||||
self.assertEqual(reloaded_image_processor.__class__.__name__, "NewImageProcessor")
|
||||
|
||||
# The image processor file is cached in the snapshot directory. So the module file is not changed after dumping
|
||||
# to a temp dir. Because the revision of the module file is not changed.
|
||||
# Test the dynamic module is loaded only once if the module file is not changed.
|
||||
self.assertIs(image_processor.__class__, reloaded_image_processor.__class__)
|
||||
|
||||
# Test the dynamic module is reloaded if we force it.
|
||||
reloaded_image_processor = AutoImageProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True, force_download=True
|
||||
)
|
||||
self.assertIsNot(image_processor.__class__, reloaded_image_processor.__class__)
|
||||
|
||||
def test_new_image_processor_registration(self):
|
||||
try:
|
||||
AutoConfig.register("custom", CustomConfig)
|
||||
|
||||
@@ -319,6 +319,10 @@ class AutoModelTest(unittest.TestCase):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
|
||||
self.assertEqual(model.__class__.__name__, "NewModel")
|
||||
|
||||
# Test the dynamic module is loaded only once.
|
||||
reloaded_model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
|
||||
self.assertIs(model.__class__, reloaded_model.__class__)
|
||||
|
||||
# Test model can be reloaded.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
@@ -328,10 +332,27 @@ class AutoModelTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
# The model file is cached in the snapshot directory. So the module file is not changed after dumping
|
||||
# to a temp dir. Because the revision of the module file is not changed.
|
||||
# Test the dynamic module is loaded only once if the module file is not changed.
|
||||
self.assertIs(model.__class__, reloaded_model.__class__)
|
||||
|
||||
# Test the dynamic module is reloaded if we force it.
|
||||
reloaded_model = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True
|
||||
)
|
||||
self.assertIsNot(model.__class__, reloaded_model.__class__)
|
||||
|
||||
# This one uses a relative import to a util file, this checks it is downloaded and used properly.
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True)
|
||||
self.assertEqual(model.__class__.__name__, "NewModel")
|
||||
|
||||
# Test the dynamic module is loaded only once.
|
||||
reloaded_model = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True
|
||||
)
|
||||
self.assertIs(model.__class__, reloaded_model.__class__)
|
||||
|
||||
# Test model can be reloaded.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
@@ -341,6 +362,17 @@ class AutoModelTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
# The model file is cached in the snapshot directory. So the module file is not changed after dumping
|
||||
# to a temp dir. Because the revision of the module file is not changed.
|
||||
# Test the dynamic module is loaded only once if the module file is not changed.
|
||||
self.assertIs(model.__class__, reloaded_model.__class__)
|
||||
|
||||
# Test the dynamic module is reloaded if we force it.
|
||||
reloaded_model = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True, force_download=True
|
||||
)
|
||||
self.assertIsNot(model.__class__, reloaded_model.__class__)
|
||||
|
||||
def test_from_pretrained_dynamic_model_distant_with_ref(self):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True)
|
||||
self.assertEqual(model.__class__.__name__, "NewModel")
|
||||
|
||||
@@ -314,6 +314,13 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True)
|
||||
self.assertTrue(tokenizer.special_attribute_present)
|
||||
|
||||
# Test the dynamic module is loaded only once.
|
||||
reloaded_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True
|
||||
)
|
||||
self.assertIs(tokenizer.__class__, reloaded_tokenizer.__class__)
|
||||
|
||||
# Test tokenizer can be reloaded.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer.save_pretrained(tmp_dir)
|
||||
@@ -340,6 +347,18 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
|
||||
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer")
|
||||
|
||||
# The tokenizer file is cached in the snapshot directory. So the module file is not changed after dumping
|
||||
# to a temp dir. Because the revision of the module file is not changed.
|
||||
# Test the dynamic module is loaded only once if the module file is not changed.
|
||||
self.assertIs(tokenizer.__class__, reloaded_tokenizer.__class__)
|
||||
|
||||
# Test the dynamic module is reloaded if we force it.
|
||||
reloaded_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, force_download=True
|
||||
)
|
||||
self.assertIsNot(tokenizer.__class__, reloaded_tokenizer.__class__)
|
||||
self.assertTrue(reloaded_tokenizer.special_attribute_present)
|
||||
|
||||
@require_tokenizers
|
||||
def test_from_pretrained_dynamic_tokenizer_conflict(self):
|
||||
class NewTokenizer(BertTokenizer):
|
||||
|
||||
Reference in New Issue
Block a user