From e1c2b69c341f088d59949d1eabfc022c12d4c797 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 6 Sep 2024 19:49:35 +0800 Subject: [PATCH] Load dynamic module (remote code) only once if code isn't change (#33162) * Load remote code only once * Use hash as load indicator * Add a new option `force_reload` for old behavior (i.e. always reload) * Add test for dynamic module is cached * Add more type annotations to improve code readability * Address comments from code review --- src/transformers/dynamic_module_utils.py | 51 ++++++++++++++----- tests/models/auto/test_configuration_auto.py | 15 ++++++ .../auto/test_feature_extraction_auto.py | 17 +++++++ .../models/auto/test_image_processing_auto.py | 17 +++++++ tests/models/auto/test_modeling_auto.py | 32 ++++++++++++ tests/models/auto/test_tokenization_auto.py | 19 +++++++ 6 files changed, 139 insertions(+), 12 deletions(-) diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 08b6701302..07cb5940dc 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -15,6 +15,7 @@ """Utilities to dynamically load objects from the Hub.""" import filecmp +import hashlib import importlib import importlib.util import os @@ -22,9 +23,11 @@ import re import shutil import signal import sys +import threading import typing import warnings from pathlib import Path +from types import ModuleType from typing import Any, Dict, List, Optional, Union from huggingface_hub import try_to_load_from_cache @@ -40,6 +43,7 @@ from .utils import ( logger = logging.get_logger(__name__) # pylint: disable=invalid-name +_HF_REMOTE_CODE_LOCK = threading.Lock() def init_hf_modules(): @@ -58,7 +62,7 @@ def init_hf_modules(): importlib.invalidate_caches() -def create_dynamic_module(name: Union[str, os.PathLike]): +def create_dynamic_module(name: Union[str, os.PathLike]) -> None: """ Creates a dynamic module in the cache directory for modules. @@ -191,13 +195,21 @@ def check_imports(filename: Union[str, os.PathLike]) -> List[str]: return get_relative_imports(filename) -def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type: +def get_class_in_module( + class_name: str, + module_path: Union[str, os.PathLike], + *, + force_reload: bool = False, +) -> typing.Type: """ Import a module on the cache directory for modules and extract a class from it. Args: class_name (`str`): The name of the class to import. module_path (`str` or `os.PathLike`): The path to the module to import. + force_reload (`bool`, *optional*, defaults to `False`): + Whether to reload the dynamic module from file if it already exists in `sys.modules`. + Otherwise, the module is only reloaded if the file has changed. Returns: `typing.Type`: The class looked for. @@ -206,15 +218,30 @@ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) - if name.endswith(".py"): name = name[:-3] name = name.replace(os.path.sep, ".") - module_spec = importlib.util.spec_from_file_location(name, location=Path(HF_MODULES_CACHE) / module_path) - module = sys.modules.get(name) - if module is None: - module = importlib.util.module_from_spec(module_spec) - # insert it into sys.modules before any loading begins - sys.modules[name] = module - # reload in both cases - module_spec.loader.exec_module(module) - return getattr(module, class_name) + module_file: Path = Path(HF_MODULES_CACHE) / module_path + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + # Hash the module file and all its relative imports to check if we need to reload it + module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file))) + module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest() + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + # reload in both cases, unless the module is already imported and the hash hits + if getattr(module, "__transformers_module_hash__", "") != module_hash: + module_spec.loader.exec_module(module) + module.__transformers_module_hash__ = module_hash + return getattr(module, class_name) def get_cached_module_file( @@ -515,7 +542,7 @@ def get_class_from_dynamic_module( local_files_only=local_files_only, repo_type=repo_type, ) - return get_class_in_module(class_name, final_module) + return get_class_in_module(class_name, final_module, force_reload=force_download) def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]: diff --git a/tests/models/auto/test_configuration_auto.py b/tests/models/auto/test_configuration_auto.py index 8b202b9092..c208985ef6 100644 --- a/tests/models/auto/test_configuration_auto.py +++ b/tests/models/auto/test_configuration_auto.py @@ -122,12 +122,27 @@ class AutoConfigTest(unittest.TestCase): config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) self.assertEqual(config.__class__.__name__, "NewModelConfig") + # Test the dynamic module is loaded only once. + reloaded_config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) + self.assertIs(config.__class__, reloaded_config.__class__) + # Test config can be reloaded. with tempfile.TemporaryDirectory() as tmp_dir: config.save_pretrained(tmp_dir) reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True) self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig") + # The configuration file is cached in the snapshot directory. So the module file is not changed after dumping + # to a temp dir. Because the revision of the configuration file is not changed. + # Test the dynamic module is loaded only once if the configuration file is not changed. + self.assertIs(config.__class__, reloaded_config.__class__) + + # Test the dynamic module is reloaded if we force it. + reloaded_config = AutoConfig.from_pretrained( + "hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True + ) + self.assertIsNot(config.__class__, reloaded_config.__class__) + def test_from_pretrained_dynamic_config_conflict(self): class NewModelConfigLocal(BertConfig): model_type = "new-model" diff --git a/tests/models/auto/test_feature_extraction_auto.py b/tests/models/auto/test_feature_extraction_auto.py index ed50006741..d36183a63c 100644 --- a/tests/models/auto/test_feature_extraction_auto.py +++ b/tests/models/auto/test_feature_extraction_auto.py @@ -116,12 +116,29 @@ class AutoFeatureExtractorTest(unittest.TestCase): ) self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor") + # Test the dynamic module is loaded only once. + reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained( + "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True + ) + self.assertIs(feature_extractor.__class__, reloaded_feature_extractor.__class__) + # Test feature extractor can be reloaded. with tempfile.TemporaryDirectory() as tmp_dir: feature_extractor.save_pretrained(tmp_dir) reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir, trust_remote_code=True) self.assertEqual(reloaded_feature_extractor.__class__.__name__, "NewFeatureExtractor") + # The feature extractor file is cached in the snapshot directory. So the module file is not changed after dumping + # to a temp dir. Because the revision of the module file is not changed. + # Test the dynamic module is loaded only once if the module file is not changed. + self.assertIs(feature_extractor.__class__, reloaded_feature_extractor.__class__) + + # Test the dynamic module is reloaded if we force it. + reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained( + "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True, force_download=True + ) + self.assertIsNot(feature_extractor.__class__, reloaded_feature_extractor.__class__) + def test_new_feature_extractor_registration(self): try: AutoConfig.register("custom", CustomConfig) diff --git a/tests/models/auto/test_image_processing_auto.py b/tests/models/auto/test_image_processing_auto.py index b571e7a860..c0046ae1c3 100644 --- a/tests/models/auto/test_image_processing_auto.py +++ b/tests/models/auto/test_image_processing_auto.py @@ -167,12 +167,29 @@ class AutoImageProcessorTest(unittest.TestCase): ) self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor") + # Test the dynamic module is loaded only once. + reloaded_image_processor = AutoImageProcessor.from_pretrained( + "hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True + ) + self.assertIs(image_processor.__class__, reloaded_image_processor.__class__) + # Test image processor can be reloaded. with tempfile.TemporaryDirectory() as tmp_dir: image_processor.save_pretrained(tmp_dir) reloaded_image_processor = AutoImageProcessor.from_pretrained(tmp_dir, trust_remote_code=True) self.assertEqual(reloaded_image_processor.__class__.__name__, "NewImageProcessor") + # The image processor file is cached in the snapshot directory. So the module file is not changed after dumping + # to a temp dir. Because the revision of the module file is not changed. + # Test the dynamic module is loaded only once if the module file is not changed. + self.assertIs(image_processor.__class__, reloaded_image_processor.__class__) + + # Test the dynamic module is reloaded if we force it. + reloaded_image_processor = AutoImageProcessor.from_pretrained( + "hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True, force_download=True + ) + self.assertIsNot(image_processor.__class__, reloaded_image_processor.__class__) + def test_new_image_processor_registration(self): try: AutoConfig.register("custom", CustomConfig) diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 61085b9b5d..39770b091b 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -319,6 +319,10 @@ class AutoModelTest(unittest.TestCase): model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) self.assertEqual(model.__class__.__name__, "NewModel") + # Test the dynamic module is loaded only once. + reloaded_model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) + self.assertIs(model.__class__, reloaded_model.__class__) + # Test model can be reloaded. with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) @@ -328,10 +332,27 @@ class AutoModelTest(unittest.TestCase): for p1, p2 in zip(model.parameters(), reloaded_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + # The model file is cached in the snapshot directory. So the module file is not changed after dumping + # to a temp dir. Because the revision of the module file is not changed. + # Test the dynamic module is loaded only once if the module file is not changed. + self.assertIs(model.__class__, reloaded_model.__class__) + + # Test the dynamic module is reloaded if we force it. + reloaded_model = AutoModel.from_pretrained( + "hf-internal-testing/test_dynamic_model", trust_remote_code=True, force_download=True + ) + self.assertIsNot(model.__class__, reloaded_model.__class__) + # This one uses a relative import to a util file, this checks it is downloaded and used properly. model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True) self.assertEqual(model.__class__.__name__, "NewModel") + # Test the dynamic module is loaded only once. + reloaded_model = AutoModel.from_pretrained( + "hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True + ) + self.assertIs(model.__class__, reloaded_model.__class__) + # Test model can be reloaded. with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) @@ -341,6 +362,17 @@ class AutoModelTest(unittest.TestCase): for p1, p2 in zip(model.parameters(), reloaded_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + # The model file is cached in the snapshot directory. So the module file is not changed after dumping + # to a temp dir. Because the revision of the module file is not changed. + # Test the dynamic module is loaded only once if the module file is not changed. + self.assertIs(model.__class__, reloaded_model.__class__) + + # Test the dynamic module is reloaded if we force it. + reloaded_model = AutoModel.from_pretrained( + "hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True, force_download=True + ) + self.assertIsNot(model.__class__, reloaded_model.__class__) + def test_from_pretrained_dynamic_model_distant_with_ref(self): model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True) self.assertEqual(model.__class__.__name__, "NewModel") diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index ad96064308..f49ece15ff 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -314,6 +314,13 @@ class AutoTokenizerTest(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True) self.assertTrue(tokenizer.special_attribute_present) + + # Test the dynamic module is loaded only once. + reloaded_tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True + ) + self.assertIs(tokenizer.__class__, reloaded_tokenizer.__class__) + # Test tokenizer can be reloaded. with tempfile.TemporaryDirectory() as tmp_dir: tokenizer.save_pretrained(tmp_dir) @@ -340,6 +347,18 @@ class AutoTokenizerTest(unittest.TestCase): self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer") + # The tokenizer file is cached in the snapshot directory. So the module file is not changed after dumping + # to a temp dir. Because the revision of the module file is not changed. + # Test the dynamic module is loaded only once if the module file is not changed. + self.assertIs(tokenizer.__class__, reloaded_tokenizer.__class__) + + # Test the dynamic module is reloaded if we force it. + reloaded_tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True, force_download=True + ) + self.assertIsNot(tokenizer.__class__, reloaded_tokenizer.__class__) + self.assertTrue(reloaded_tokenizer.special_attribute_present) + @require_tokenizers def test_from_pretrained_dynamic_tokenizer_conflict(self): class NewTokenizer(BertTokenizer):