Add documentation to dynamic module utils (#25534)
* Add documentation to dynamic module utils * Address review comments
This commit is contained in:
@@ -20,9 +20,10 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
HF_MODULES_CACHE,
|
HF_MODULES_CACHE,
|
||||||
@@ -57,6 +58,10 @@ def init_hf_modules():
|
|||||||
def create_dynamic_module(name: Union[str, os.PathLike]):
|
def create_dynamic_module(name: Union[str, os.PathLike]):
|
||||||
"""
|
"""
|
||||||
Creates a dynamic module in the cache directory for modules.
|
Creates a dynamic module in the cache directory for modules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (`str` or `os.PathLike`):
|
||||||
|
The name of the dynamic module to create.
|
||||||
"""
|
"""
|
||||||
init_hf_modules()
|
init_hf_modules()
|
||||||
dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
|
dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
|
||||||
@@ -67,15 +72,20 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
|
|||||||
init_path = dynamic_module_path / "__init__.py"
|
init_path = dynamic_module_path / "__init__.py"
|
||||||
if not init_path.exists():
|
if not init_path.exists():
|
||||||
init_path.touch()
|
init_path.touch()
|
||||||
|
# It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
|
||||||
|
# with errors about modules that do not exist. Same for all other `invalidate_caches` in this file.
|
||||||
importlib.invalidate_caches()
|
importlib.invalidate_caches()
|
||||||
|
|
||||||
|
|
||||||
def get_relative_imports(module_file):
|
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the list of modules that are relatively imported in a module file.
|
Get the list of modules that are relatively imported in a module file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of relative imports in the module.
|
||||||
"""
|
"""
|
||||||
with open(module_file, "r", encoding="utf-8") as f:
|
with open(module_file, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
@@ -88,13 +98,17 @@ def get_relative_imports(module_file):
|
|||||||
return list(set(relative_imports))
|
return list(set(relative_imports))
|
||||||
|
|
||||||
|
|
||||||
def get_relative_import_files(module_file):
|
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the list of all files that are needed for a given module. Note that this function recurses through the relative
|
Get the list of all files that are needed for a given module. Note that this function recurses through the relative
|
||||||
imports (if a imports b and b imports c, it will return module files for b and c).
|
imports (if a imports b and b imports c, it will return module files for b and c).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
|
||||||
|
of module files a given module needs.
|
||||||
"""
|
"""
|
||||||
no_change = False
|
no_change = False
|
||||||
files_to_check = [module_file]
|
files_to_check = [module_file]
|
||||||
@@ -117,9 +131,15 @@ def get_relative_import_files(module_file):
|
|||||||
return all_relative_imports
|
return all_relative_imports
|
||||||
|
|
||||||
|
|
||||||
def get_imports(filename):
|
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Extracts all the libraries that are imported in a file.
|
Extracts all the libraries (not relative imports this time) that are imported in a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (`str` or `os.PathLike`): The module file to inspect.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of all packages required to use the input module.
|
||||||
"""
|
"""
|
||||||
with open(filename, "r", encoding="utf-8") as f:
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
@@ -136,9 +156,16 @@ def get_imports(filename):
|
|||||||
return list(set(imports))
|
return list(set(imports))
|
||||||
|
|
||||||
|
|
||||||
def check_imports(filename):
|
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Check if the current Python environment contains all the libraries that are imported in a file.
|
Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
|
||||||
|
library is missing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (`str` or `os.PathLike`): The module file to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of relative imports in the file.
|
||||||
"""
|
"""
|
||||||
imports = get_imports(filename)
|
imports = get_imports(filename)
|
||||||
missing_packages = []
|
missing_packages = []
|
||||||
@@ -157,9 +184,16 @@ def check_imports(filename):
|
|||||||
return get_relative_imports(filename)
|
return get_relative_imports(filename)
|
||||||
|
|
||||||
|
|
||||||
def get_class_in_module(class_name, module_path):
|
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
|
||||||
"""
|
"""
|
||||||
Import a module on the cache directory for modules and extract a class from it.
|
Import a module on the cache directory for modules and extract a class from it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_name (`str`): The name of the class to import.
|
||||||
|
module_path (`str` or `os.PathLike`): The path to the module to import.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`typing.Type`: The class looked for.
|
||||||
"""
|
"""
|
||||||
module_path = module_path.replace(os.path.sep, ".")
|
module_path = module_path.replace(os.path.sep, ".")
|
||||||
module = importlib.import_module(module_path)
|
module = importlib.import_module(module_path)
|
||||||
@@ -179,7 +213,7 @@ def get_cached_module_file(
|
|||||||
repo_type: Optional[str] = None,
|
repo_type: Optional[str] = None,
|
||||||
_commit_hash: Optional[str] = None,
|
_commit_hash: Optional[str] = None,
|
||||||
**deprecated_kwargs,
|
**deprecated_kwargs,
|
||||||
):
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
||||||
Transformers module.
|
Transformers module.
|
||||||
@@ -354,7 +388,7 @@ def get_class_from_dynamic_module(
|
|||||||
repo_type: Optional[str] = None,
|
repo_type: Optional[str] = None,
|
||||||
code_revision: Optional[str] = None,
|
code_revision: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> typing.Type:
|
||||||
"""
|
"""
|
||||||
Extracts a class from a module file, present in the local folder or repository of a model.
|
Extracts a class from a module file, present in the local folder or repository of a model.
|
||||||
|
|
||||||
@@ -416,7 +450,7 @@ def get_class_from_dynamic_module(
|
|||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`type`: The class, dynamically imported from the module.
|
`typing.Type`: The class, dynamically imported from the module.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -463,7 +497,7 @@ def get_class_from_dynamic_module(
|
|||||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
||||||
|
|
||||||
|
|
||||||
def custom_object_save(obj, folder, config=None):
|
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
|
Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
|
||||||
adds the proper fields in a config.
|
adds the proper fields in a config.
|
||||||
@@ -473,6 +507,9 @@ def custom_object_save(obj, folder, config=None):
|
|||||||
folder (`str` or `os.PathLike`): The folder where to save.
|
folder (`str` or `os.PathLike`): The folder where to save.
|
||||||
config (`PretrainedConfig` or dictionary, `optional`):
|
config (`PretrainedConfig` or dictionary, `optional`):
|
||||||
A config in which to register the auto_map corresponding to this custom object.
|
A config in which to register the auto_map corresponding to this custom object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`List[str]`: The list of files saved.
|
||||||
"""
|
"""
|
||||||
if obj.__module__ == "__main__":
|
if obj.__module__ == "__main__":
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user