Load dynamic module (remote code) only once if the code hasn't changed (#33162)

* Load remote code only once

* Use hash as load indicator

* Add a new option `force_reload` for old behavior (i.e. always reload)

* Add a test that the dynamic module is cached

* Add more type annotations to improve code readability

* Address comments from code review
This commit is contained in:
Xuehai Pan
2024-09-06 19:49:35 +08:00
committed by GitHub
parent 1bd9d1c899
commit e1c2b69c34
6 changed files with 139 additions and 12 deletions

View File

@@ -15,6 +15,7 @@
"""Utilities to dynamically load objects from the Hub.""" """Utilities to dynamically load objects from the Hub."""
import filecmp import filecmp
import hashlib
import importlib import importlib
import importlib.util import importlib.util
import os import os
@@ -22,9 +23,11 @@ import re
import shutil import shutil
import signal import signal
import sys import sys
import threading
import typing import typing
import warnings import warnings
from pathlib import Path from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from huggingface_hub import try_to_load_from_cache from huggingface_hub import try_to_load_from_cache
@@ -40,6 +43,7 @@ from .utils import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_HF_REMOTE_CODE_LOCK = threading.Lock()
def init_hf_modules(): def init_hf_modules():
@@ -58,7 +62,7 @@ def init_hf_modules():
importlib.invalidate_caches() importlib.invalidate_caches()
def create_dynamic_module(name: Union[str, os.PathLike]): def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
""" """
Creates a dynamic module in the cache directory for modules. Creates a dynamic module in the cache directory for modules.
@@ -191,13 +195,21 @@ def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
return get_relative_imports(filename) return get_relative_imports(filename)
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type: def get_class_in_module(
class_name: str,
module_path: Union[str, os.PathLike],
*,
force_reload: bool = False,
) -> typing.Type:
""" """
Import a module on the cache directory for modules and extract a class from it. Import a module on the cache directory for modules and extract a class from it.
Args: Args:
class_name (`str`): The name of the class to import. class_name (`str`): The name of the class to import.
module_path (`str` or `os.PathLike`): The path to the module to import. module_path (`str` or `os.PathLike`): The path to the module to import.
force_reload (`bool`, *optional*, defaults to `False`):
Whether to reload the dynamic module from file if it already exists in `sys.modules`.
Otherwise, the module is only reloaded if the file has changed.
Returns: Returns:
`typing.Type`: The class looked for. `typing.Type`: The class looked for.
@@ -206,15 +218,30 @@ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -
if name.endswith(".py"): if name.endswith(".py"):
name = name[:-3] name = name[:-3]
name = name.replace(os.path.sep, ".") name = name.replace(os.path.sep, ".")
module_spec = importlib.util.spec_from_file_location(name, location=Path(HF_MODULES_CACHE) / module_path) module_file: Path = Path(HF_MODULES_CACHE) / module_path
module = sys.modules.get(name) with _HF_REMOTE_CODE_LOCK:
if module is None: if force_reload:
module = importlib.util.module_from_spec(module_spec) sys.modules.pop(name, None)
# insert it into sys.modules before any loading begins importlib.invalidate_caches()
sys.modules[name] = module cached_module: Optional[ModuleType] = sys.modules.get(name)
# reload in both cases module_spec = importlib.util.spec_from_file_location(name, location=module_file)
module_spec.loader.exec_module(module)
return getattr(module, class_name) # Hash the module file and all its relative imports to check if we need to reload it
module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
# reload in both cases, unless the module is already imported and the hash hits
if getattr(module, "__transformers_module_hash__", "") != module_hash:
module_spec.loader.exec_module(module)
module.__transformers_module_hash__ = module_hash
return getattr(module, class_name)
def get_cached_module_file( def get_cached_module_file(
@@ -515,7 +542,7 @@ def get_class_from_dynamic_module(
local_files_only=local_files_only, local_files_only=local_files_only,
repo_type=repo_type, repo_type=repo_type,
) )
return get_class_in_module(class_name, final_module) return get_class_in_module(class_name, final_module, force_reload=force_download)
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]: def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:

View File

@@ -122,12 +122,27 @@ class AutoConfigTest(unittest.TestCase):
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(config.__class__.__name__, "NewModelConfig") self.assertEqual(config.__class__.__name__, "NewModelConfig")
# Test the dynamic module is loaded only once.
reloaded_config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertIs(config.__class__, reloaded_config.__class__)
# Test config can be reloaded. # Test config can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir) config.save_pretrained(tmp_dir)
reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True) reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig") self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig")
# The configuration file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the configuration file is not changed.
# Test the dynamic module is loaded only once if the configuration file is not changed.
self.assertIs(config.__class__, reloaded_config.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_config = AutoConfig.from_pretrained(
"hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True
)
self.assertIsNot(config.__class__, reloaded_config.__class__)
def test_from_pretrained_dynamic_config_conflict(self): def test_from_pretrained_dynamic_config_conflict(self):
class NewModelConfigLocal(BertConfig): class NewModelConfigLocal(BertConfig):
model_type = "new-model" model_type = "new-model"

View File

@@ -116,12 +116,29 @@ class AutoFeatureExtractorTest(unittest.TestCase):
) )
self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor") self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
# Test the dynamic module is loaded only once.
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
)
self.assertIs(feature_extractor.__class__, reloaded_feature_extractor.__class__)
# Test feature extractor can be reloaded. # Test feature extractor can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(tmp_dir) feature_extractor.save_pretrained(tmp_dir)
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir, trust_remote_code=True) reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_feature_extractor.__class__.__name__, "NewFeatureExtractor") self.assertEqual(reloaded_feature_extractor.__class__.__name__, "NewFeatureExtractor")
# The feature extractor file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(feature_extractor.__class__, reloaded_feature_extractor.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True, force_download=True
)
self.assertIsNot(feature_extractor.__class__, reloaded_feature_extractor.__class__)
def test_new_feature_extractor_registration(self): def test_new_feature_extractor_registration(self):
try: try:
AutoConfig.register("custom", CustomConfig) AutoConfig.register("custom", CustomConfig)

View File

@@ -167,12 +167,29 @@ class AutoImageProcessorTest(unittest.TestCase):
) )
self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor") self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor")
# Test the dynamic module is loaded only once.
reloaded_image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True
)
self.assertIs(image_processor.__class__, reloaded_image_processor.__class__)
# Test image processor can be reloaded. # Test image processor can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained(tmp_dir) image_processor.save_pretrained(tmp_dir)
reloaded_image_processor = AutoImageProcessor.from_pretrained(tmp_dir, trust_remote_code=True) reloaded_image_processor = AutoImageProcessor.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_image_processor.__class__.__name__, "NewImageProcessor") self.assertEqual(reloaded_image_processor.__class__.__name__, "NewImageProcessor")
# The image processor file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(image_processor.__class__, reloaded_image_processor.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True, force_download=True
)
self.assertIsNot(image_processor.__class__, reloaded_image_processor.__class__)
def test_new_image_processor_registration(self): def test_new_image_processor_registration(self):
try: try:
AutoConfig.register("custom", CustomConfig) AutoConfig.register("custom", CustomConfig)

View File

@@ -319,6 +319,10 @@ class AutoModelTest(unittest.TestCase):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel") self.assertEqual(model.__class__.__name__, "NewModel")
# Test the dynamic module is loaded only once.
reloaded_model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertIs(model.__class__, reloaded_model.__class__)
# Test model can be reloaded. # Test model can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
@@ -328,10 +332,27 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()): for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
# The model file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(model.__class__, reloaded_model.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_model = AutoModel.from_pretrained(
"hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True
)
self.assertIsNot(model.__class__, reloaded_model.__class__)
# This one uses a relative import to a util file, this checks it is downloaded and used properly. # This one uses a relative import to a util file, this checks it is downloaded and used properly.
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True) model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel") self.assertEqual(model.__class__.__name__, "NewModel")
# Test the dynamic module is loaded only once.
reloaded_model = AutoModel.from_pretrained(
"hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True
)
self.assertIs(model.__class__, reloaded_model.__class__)
# Test model can be reloaded. # Test model can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
@@ -341,6 +362,17 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()): for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
# The model file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(model.__class__, reloaded_model.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_model = AutoModel.from_pretrained(
"hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True, force_download=True
)
self.assertIsNot(model.__class__, reloaded_model.__class__)
def test_from_pretrained_dynamic_model_distant_with_ref(self): def test_from_pretrained_dynamic_model_distant_with_ref(self):
model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True) model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel") self.assertEqual(model.__class__.__name__, "NewModel")

View File

@@ -314,6 +314,13 @@ class AutoTokenizerTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True)
self.assertTrue(tokenizer.special_attribute_present) self.assertTrue(tokenizer.special_attribute_present)
# Test the dynamic module is loaded only once.
reloaded_tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True
)
self.assertIs(tokenizer.__class__, reloaded_tokenizer.__class__)
# Test tokenizer can be reloaded. # Test tokenizer can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir) tokenizer.save_pretrained(tmp_dir)
@@ -340,6 +347,18 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer") self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer")
# The tokenizer file is cached in the snapshot directory. So the module file is not changed after dumping
# to a temp dir. Because the revision of the module file is not changed.
# Test the dynamic module is loaded only once if the module file is not changed.
self.assertIs(tokenizer.__class__, reloaded_tokenizer.__class__)
# Test the dynamic module is reloaded if we force it.
reloaded_tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, force_download=True
)
self.assertIsNot(tokenizer.__class__, reloaded_tokenizer.__class__)
self.assertTrue(reloaded_tokenizer.special_attribute_present)
@require_tokenizers @require_tokenizers
def test_from_pretrained_dynamic_tokenizer_conflict(self): def test_from_pretrained_dynamic_tokenizer_conflict(self):
class NewTokenizer(BertTokenizer): class NewTokenizer(BertTokenizer):