From 0b072304099e14f86dc62b59ca84f7eb5676af4e Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 27 Jan 2022 14:47:59 -0500 Subject: [PATCH] Allow relative imports in dynamic code (#15352) * Allow dynamic modules to use relative imports * Add tests * Add one last test * Changes --- src/transformers/models/auto/dynamic.py | 215 ++++++++++++++++++------ tests/test_configuration_auto.py | 4 + tests/test_modeling_auto.py | 10 +- 3 files changed, 174 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/auto/dynamic.py b/src/transformers/models/auto/dynamic.py index 20b968ed6a..27fe385b7f 100644 --- a/src/transformers/models/auto/dynamic.py +++ b/src/transformers/models/auto/dynamic.py @@ -22,6 +22,8 @@ import sys from pathlib import Path from typing import Dict, Optional, Union +from huggingface_hub import HfFolder, model_info + from ...file_utils import ( HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, @@ -79,6 +81,12 @@ def check_imports(filename): # Only keep the top-level module imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] + # Imports of the form `import .xxx` + relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) + # Imports of the form `from .xxx import yyy` + relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) + relative_imports = list(set(relative_imports)) + # Unique-ify and test we got them all imports = list(set(imports)) missing_packages = [] @@ -94,6 +102,8 @@ def check_imports(filename): f"{', '.join(missing_packages)}. 
Run `pip install {' '.join(missing_packages)}`" ) + return relative_imports + def get_class_in_module(class_name, module_path): """ @@ -104,6 +114,145 @@ def get_class_in_module(class_name, module_path): return getattr(module, class_name) +def get_cached_module_file( + pretrained_model_name_or_path: Union[str, os.PathLike], + module_file: str, + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, +): + """ + Downloads a module from a local folder or a distant repo and returns its path inside the cached + Transformers module. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced + under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + module_file (`str`): + The name of the module file containing the class to look for. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. 
+ proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + `str`: The path to the module inside the cache.""" + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file. 
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) + submodule = "local" + else: + module_file_or_url = hf_bucket_url( + pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None + ) + submodule = pretrained_model_name_or_path.replace("/", os.path.sep) + + try: + # Load from URL or cache if already cached + resolved_module_file = cached_path( + module_file_or_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + ) + + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + # Check we have all the requirements in our environment + modules_needed = check_imports(resolved_module_file) + + # Now we move the module inside our cached dynamic modules. + full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule + create_dynamic_module(full_submodule) + submodule_path = Path(HF_MODULES_CACHE) / full_submodule + if submodule == "local": + # We always copy local files (we could hash the file to see if there was a change, and give them the name of + # that hash, to only copy when there is a modification but it seems overkill for now). + # The only reason we do the copy is to avoid putting too many folders in sys.path. + shutil.copy(resolved_module_file, submodule_path / module_file) + for module_needed in modules_needed: + module_needed = f"{module_needed}.py" + shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) + else: + # Get the commit hash + # TODO: we will get this info in the etag soon, so retrieve it from there. 
+ if isinstance(use_auth_token, str): + token = use_auth_token + elif use_auth_token is True: + token = HfFolder.get_token() + else: + token = None + + commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha + + # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the + # benefit of versioning. + submodule_path = submodule_path / commit_hash + full_submodule = full_submodule + os.path.sep + commit_hash + create_dynamic_module(full_submodule) + + if not (submodule_path / module_file).exists(): + shutil.copy(resolved_module_file, submodule_path / module_file) + # Make sure we also have every file pulled in through relative imports + for module_needed in modules_needed: + if not (submodule_path / module_needed).exists(): + get_cached_module_file( + pretrained_model_name_or_path, + f"{module_needed}.py", + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + return os.path.join(full_submodule, module_file) + + def get_class_from_dynamic_module( pretrained_model_name_or_path: Union[str, os.PathLike], module_file: str, @@ -178,58 +327,16 @@ def get_class_from_dynamic_module( # module. cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") ```""" - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - - # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. 
- pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) - submodule = "local" - else: - module_file_or_url = hf_bucket_url( - pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None - ) - submodule = pretrained_model_name_or_path.replace("/", os.path.sep) - - try: - # Load from URL or cache if already cached - resolved_module_file = cached_path( - module_file_or_url, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - ) - - except EnvironmentError: - logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") - raise - - # Check we have all the requirements in our environment - check_imports(resolved_module_file) - - # Now we move the module inside our cached dynamic modules. - full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule - create_dynamic_module(full_submodule) - submodule_path = Path(HF_MODULES_CACHE) / full_submodule - if submodule == "local": - # We always copy local files (we could hash the file to see if there was a change, and give them the name of - # that hash, to only copy when there is a modification but it seems overkill for now). - # The only reason we do the copy is to avoid putting too many folders in sys.path. - module_name = module_file - shutil.copy(resolved_module_file, submodule_path / module_file) - else: - # The module file will end up being named module_file + the etag. This way we get the benefit of versioning. 
- resolved_module_file_name = Path(resolved_module_file).name - module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".") - module_name = "_".join(module_name_parts) + ".py" - if not (submodule_path / module_name).exists(): - shutil.copy(resolved_module_file, submodule_path / module_name) - # And lastly we get the class inside our newly created module - final_module = os.path.join(full_submodule, module_name.replace(".py", "")) - return get_class_in_module(class_name, final_module) + final_module = get_cached_module_file( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + ) + return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/tests/test_configuration_auto.py b/tests/test_configuration_auto.py index a20be92853..6c8fdfb79f 100644 --- a/tests/test_configuration_auto.py +++ b/tests/test_configuration_auto.py @@ -102,3 +102,7 @@ class AutoConfigTest(unittest.TestCase): "hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.", ): _ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo") + + def test_from_pretrained_dynamic_config(self): + config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) + self.assertEqual(config.__class__.__name__, "NewModelConfig") diff --git a/tests/test_modeling_auto.py b/tests/test_modeling_auto.py index b732ff1b8d..4575d316ad 100644 --- a/tests/test_modeling_auto.py +++ b/tests/test_modeling_auto.py @@ -324,7 +324,7 @@ class AutoModelTest(unittest.TestCase): for child, parent in [(a, b) for a in child_model for b in parent_model]: assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}" - def test_from_pretrained_dynamic_model(self): + def 
test_from_pretrained_dynamic_model_local(self): config = BertConfig( vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 ) @@ -340,6 +340,14 @@ class AutoModelTest(unittest.TestCase): for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + def test_from_pretrained_dynamic_model_distant(self): + model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) + self.assertEqual(model.__class__.__name__, "NewModel") + + # This one uses a relative import to a util file, this checks it is downloaded and used properly. + model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True) + self.assertEqual(model.__class__.__name__, "NewModel") + def test_new_model_registration(self): AutoConfig.register("new-model", NewModelConfig)