Allow relative imports in dynamic code (#15352)
* Allow dynamic modules to use relative imports * Add tests * Add one last test * Changes
This commit is contained in:
@@ -22,6 +22,8 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, model_info
|
||||||
|
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
HF_MODULES_CACHE,
|
HF_MODULES_CACHE,
|
||||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
||||||
@@ -79,6 +81,12 @@ def check_imports(filename):
|
|||||||
# Only keep the top-level module
|
# Only keep the top-level module
|
||||||
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
|
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
|
||||||
|
|
||||||
|
# Imports of the form `import .xxx`
|
||||||
|
relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
|
||||||
|
# Imports of the form `from .xxx import yyy`
|
||||||
|
relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
|
||||||
|
relative_imports = list(set(relative_imports))
|
||||||
|
|
||||||
# Unique-ify and test we got them all
|
# Unique-ify and test we got them all
|
||||||
imports = list(set(imports))
|
imports = list(set(imports))
|
||||||
missing_packages = []
|
missing_packages = []
|
||||||
@@ -94,6 +102,8 @@ def check_imports(filename):
|
|||||||
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
|
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return relative_imports
|
||||||
|
|
||||||
|
|
||||||
def get_class_in_module(class_name, module_path):
|
def get_class_in_module(class_name, module_path):
|
||||||
"""
|
"""
|
||||||
@@ -104,6 +114,145 @@ def get_class_in_module(class_name, module_path):
|
|||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
):
    """
    Downloads a module from a local folder or a distant repo and returns its path inside the cached Transformers
    module. Every module that `module_file` pulls in through a relative import (as reported by `check_imports`) is
    copied or downloaded next to it.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
              under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.

    Raises:
        `EnvironmentError`: re-raised (after logging an error) when `module_file` cannot be resolved.
    """
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
        module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
        submodule = "local"
    else:
        module_file_or_url = hf_bucket_url(
            pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
        )
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_path(
            module_file_or_url,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment. The return value lists the modules reached through
    # relative imports; they have to live next to `module_file` in the cache.
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        shutil.copy(resolved_module_file, submodule_path / module_file)
        for module_needed in modules_needed:
            needed_file = f"{module_needed}.py"
            shutil.copy(os.path.join(pretrained_model_name_or_path, needed_file), submodule_path / needed_file)
    else:
        # Get the commit hash
        # TODO: we will get this info in the etag soon, so retrieve it from there.
        if isinstance(use_auth_token, str):
            token = use_auth_token
        elif use_auth_token is True:
            token = HfFolder.get_token()
        else:
            token = None

        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
        # Make sure we also have every file reached through a relative import.
        for module_needed in modules_needed:
            needed_file = f"{module_needed}.py"
            # The helper is stored on disk under its file name (with the `.py` suffix), so that is the path whose
            # existence must be tested; testing the bare module name would never match and would re-fetch the
            # helper on every call.
            if not (submodule_path / needed_file).exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    needed_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
    return os.path.join(full_submodule, module_file)
|
||||||
|
|
||||||
|
|
||||||
def get_class_from_dynamic_module(
|
def get_class_from_dynamic_module(
|
||||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||||
module_file: str,
|
module_file: str,
|
||||||
@@ -178,58 +327,16 @@ def get_class_from_dynamic_module(
|
|||||||
# module.
|
# module.
|
||||||
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
|
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
|
||||||
```"""
|
```"""
|
||||||
if is_offline_mode() and not local_files_only:
|
|
||||||
logger.info("Offline mode: forcing local_files_only=True")
|
|
||||||
local_files_only = True
|
|
||||||
|
|
||||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
|
||||||
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
|
||||||
submodule = "local"
|
|
||||||
else:
|
|
||||||
module_file_or_url = hf_bucket_url(
|
|
||||||
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
|
|
||||||
)
|
|
||||||
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load from URL or cache if already cached
|
|
||||||
resolved_module_file = cached_path(
|
|
||||||
module_file_or_url,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
except EnvironmentError:
|
|
||||||
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Check we have all the requirements in our environment
|
|
||||||
check_imports(resolved_module_file)
|
|
||||||
|
|
||||||
# Now we move the module inside our cached dynamic modules.
|
|
||||||
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
|
||||||
create_dynamic_module(full_submodule)
|
|
||||||
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
|
||||||
if submodule == "local":
|
|
||||||
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
|
|
||||||
# that hash, to only copy when there is a modification but it seems overkill for now).
|
|
||||||
# The only reason we do the copy is to avoid putting too many folders in sys.path.
|
|
||||||
module_name = module_file
|
|
||||||
shutil.copy(resolved_module_file, submodule_path / module_file)
|
|
||||||
else:
|
|
||||||
# The module file will end up being named module_file + the etag. This way we get the benefit of versioning.
|
|
||||||
resolved_module_file_name = Path(resolved_module_file).name
|
|
||||||
module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".")
|
|
||||||
module_name = "_".join(module_name_parts) + ".py"
|
|
||||||
if not (submodule_path / module_name).exists():
|
|
||||||
shutil.copy(resolved_module_file, submodule_path / module_name)
|
|
||||||
|
|
||||||
# And lastly we get the class inside our newly created module
|
# And lastly we get the class inside our newly created module
|
||||||
final_module = os.path.join(full_submodule, module_name.replace(".py", ""))
|
final_module = get_cached_module_file(
|
||||||
return get_class_in_module(class_name, final_module)
|
pretrained_model_name_or_path,
|
||||||
|
module_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
||||||
|
|||||||
@@ -102,3 +102,7 @@ class AutoConfigTest(unittest.TestCase):
|
|||||||
"hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.",
|
"hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.",
|
||||||
):
|
):
|
||||||
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
|
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
|
||||||
|
|
||||||
|
def test_from_pretrained_dynamic_config(self):
    """A config class defined by remote code in the Hub repo is instantiated by `AutoConfig`."""
    repo_id = "hf-internal-testing/test_dynamic_model"
    loaded_config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    # Compare by class name: the class lives in a dynamically created module and cannot be imported here.
    self.assertEqual(type(loaded_config).__name__, "NewModelConfig")
|
||||||
|
|||||||
@@ -324,7 +324,7 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||||
|
|
||||||
def test_from_pretrained_dynamic_model(self):
|
def test_from_pretrained_dynamic_model_local(self):
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
)
|
)
|
||||||
@@ -340,6 +340,14 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
def test_from_pretrained_dynamic_model_distant(self):
    """Models defined by remote code in Hub repos are instantiated by `AutoModel`."""
    repo_ids = (
        "hf-internal-testing/test_dynamic_model",
        # This one uses a relative import to a util file, this checks it is downloaded and used properly.
        "hf-internal-testing/test_dynamic_model_with_util",
    )
    for repo_id in repo_ids:
        loaded_model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        self.assertEqual(type(loaded_model).__name__, "NewModel")
|
||||||
|
|
||||||
def test_new_model_registration(self):
|
def test_new_model_registration(self):
|
||||||
AutoConfig.register("new-model", NewModelConfig)
|
AutoConfig.register("new-model", NewModelConfig)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user