From 297a6a7aea941445845481c4dca60dcbe9aa1b75 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 17 Aug 2023 08:28:06 +0200 Subject: [PATCH] Add documentation to dynamic module utils (#25534) * Add documentation to dynamic module utils * Address review comments --- src/transformers/dynamic_module_utils.py | 61 +++++++++++++++++++----- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index d59257b273..031ab8b7d1 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -20,9 +20,10 @@ import re import shutil import signal import sys +import typing import warnings from pathlib import Path -from typing import Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from .utils import ( HF_MODULES_CACHE, @@ -57,6 +58,10 @@ def init_hf_modules(): def create_dynamic_module(name: Union[str, os.PathLike]): """ Creates a dynamic module in the cache directory for modules. + + Args: + name (`str` or `os.PathLike`): + The name of the dynamic module to create. """ init_hf_modules() dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve() @@ -67,15 +72,20 @@ def create_dynamic_module(name: Union[str, os.PathLike]): init_path = dynamic_module_path / "__init__.py" if not init_path.exists(): init_path.touch() + # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up + # with errors about modules that do not exist. Same for all other `invalidate_caches` in this file. importlib.invalidate_caches() -def get_relative_imports(module_file): +def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]: """ Get the list of modules that are relatively imported in a module file. Args: module_file (`str` or `os.PathLike`): The module file to inspect. 
+ + Returns: + `List[str]`: The list of relative imports in the module. """ with open(module_file, "r", encoding="utf-8") as f: content = f.read() @@ -88,13 +98,17 @@ def get_relative_imports(module_file): return list(set(relative_imports)) -def get_relative_import_files(module_file): +def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]: """ Get the list of all files that are needed for a given module. Note that this function recurses through the relative imports (if a imports b and b imports c, it will return module files for b and c). Args: module_file (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list + of module files a given module needs. """ no_change = False files_to_check = [module_file] @@ -117,9 +131,15 @@ def get_relative_import_files(module_file): return all_relative_imports -def get_imports(filename): +def get_imports(filename: Union[str, os.PathLike]) -> List[str]: """ - Extracts all the libraries that are imported in a file. + Extracts all the libraries (not relative imports this time) that are imported in a file. + + Args: + filename (`str` or `os.PathLike`): The module file to inspect. + + Returns: + `List[str]`: The list of all packages required to use the input module. """ with open(filename, "r", encoding="utf-8") as f: content = f.read() @@ -136,9 +156,16 @@ def get_imports(filename): return list(set(imports)) -def check_imports(filename): +def check_imports(filename: Union[str, os.PathLike]) -> List[str]: """ - Check if the current Python environment contains all the libraries that are imported in a file. + Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a + library is missing. + + Args: + filename (`str` or `os.PathLike`): The module file to check. + + Returns: + `List[str]`: The list of relative imports in the file. 
""" imports = get_imports(filename) missing_packages = [] @@ -157,9 +184,16 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path): +def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type: """ Import a module on the cache directory for modules and extract a class from it. + + Args: + class_name (`str`): The name of the class to import. + module_path (`str` or `os.PathLike`): The path to the module to import. + + Returns: + `typing.Type`: The class looked for. """ module_path = module_path.replace(os.path.sep, ".") module = importlib.import_module(module_path) @@ -179,7 +213,7 @@ def get_cached_module_file( repo_type: Optional[str] = None, _commit_hash: Optional[str] = None, **deprecated_kwargs, -): +) -> str: """ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached Transformers module. @@ -354,7 +388,7 @@ def get_class_from_dynamic_module( repo_type: Optional[str] = None, code_revision: Optional[str] = None, **kwargs, -): +) -> typing.Type: """ Extracts a class from a module file, present in the local folder or repository of a model. @@ -416,7 +450,7 @@ def get_class_from_dynamic_module( Returns: - `type`: The class, dynamically imported from the module. + `typing.Type`: The class, dynamically imported from the module. Examples: @@ -463,7 +497,7 @@ def get_class_from_dynamic_module( return get_class_in_module(class_name, final_module.replace(".py", "")) -def custom_object_save(obj, folder, config=None): +def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]: """ Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally adds the proper fields in a config. @@ -473,6 +507,9 @@ def custom_object_save(obj, folder, config=None): folder (`str` or `os.PathLike`): The folder where to save. 
config (`PretrainedConfig` or dictionary, `optional`): A config in which to register the auto_map corresponding to this custom object. + + Returns: + `List[str]`: The list of files saved. """ if obj.__module__ == "__main__": logger.warning(