Use code on the Hub from another repo (#22698)
* initial work
* Add other classes
* Refactor code
* Move warning and fix dynamic pipeline
* Issue warning when necessary
* Add test
This commit is contained in:
@@ -29,6 +29,7 @@ from .utils import (
|
||||
extract_commit_hash,
|
||||
is_offline_mode,
|
||||
logging,
|
||||
try_to_load_from_cache,
|
||||
)
|
||||
|
||||
|
||||
@@ -222,11 +223,16 @@ def get_cached_module_file(
|
||||
|
||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if is_local:
|
||||
submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
|
||||
else:
|
||||
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
||||
cached_module = try_to_load_from_cache(
|
||||
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
|
||||
)
|
||||
|
||||
new_files = []
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_module_file = cached_file(
|
||||
@@ -241,6 +247,8 @@ def get_cached_module_file(
|
||||
revision=revision,
|
||||
_commit_hash=_commit_hash,
|
||||
)
|
||||
if not is_local and cached_module != resolved_module_file:
|
||||
new_files.append(module_file)
|
||||
|
||||
except EnvironmentError:
|
||||
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||
@@ -284,7 +292,7 @@ def get_cached_module_file(
|
||||
importlib.invalidate_caches()
|
||||
# Make sure we also have every file needed by relative imports
|
||||
for module_needed in modules_needed:
|
||||
if not (submodule_path / module_needed).exists():
|
||||
if not (submodule_path / f"{module_needed}.py").exists():
|
||||
get_cached_module_file(
|
||||
pretrained_model_name_or_path,
|
||||
f"{module_needed}.py",
|
||||
@@ -295,14 +303,24 @@ def get_cached_module_file(
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
new_files.append(f"{module_needed}.py")
|
||||
|
||||
if len(new_files) > 0:
|
||||
new_files = "\n".join([f"- {f}" for f in new_files])
|
||||
logger.warning(
|
||||
f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
|
||||
"\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
|
||||
"versions of the code file, you can pin a revision."
|
||||
)
|
||||
|
||||
return os.path.join(full_submodule, module_file)
|
||||
|
||||
|
||||
def get_class_from_dynamic_module(
|
||||
class_reference: str,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
module_file: str,
|
||||
class_name: str,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = False,
|
||||
@@ -323,6 +341,8 @@ def get_class_from_dynamic_module(
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
class_reference (`str`):
|
||||
The full name of the class to load, including its module and optionally its repo.
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
@@ -332,6 +352,7 @@ def get_class_from_dynamic_module(
|
||||
- a path to a *directory* containing a configuration file saved using the
|
||||
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
||||
|
||||
This is used when `class_reference` does not specify another repo.
|
||||
module_file (`str`):
|
||||
The name of the module file containing the class to look for.
|
||||
class_name (`str`):
|
||||
@@ -371,12 +392,25 @@ def get_class_from_dynamic_module(
|
||||
```python
|
||||
# Download module `modeling.py` from huggingface.co and cache it, then extract the class `MyBertModel` from this
|
||||
# module.
|
||||
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
|
||||
cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
|
||||
|
||||
# Download module `modeling.py` from a given repo and cache it, then extract the class `MyBertModel` from this
|
||||
# module.
|
||||
cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
|
||||
```"""
|
||||
# Catch the name of the repo if it's specified in `class_reference`
|
||||
if "--" in class_reference:
|
||||
repo_id, class_reference = class_reference.split("--")
|
||||
# Invalidate revision since it's not relevant for this repo
|
||||
revision = "main"
|
||||
else:
|
||||
repo_id = pretrained_model_name_or_path
|
||||
module_file, class_name = class_reference.split(".")
|
||||
|
||||
# And lastly we get the class inside our newly created module
|
||||
final_module = get_cached_module_file(
|
||||
pretrained_model_name_or_path,
|
||||
module_file,
|
||||
repo_id,
|
||||
module_file + ".py",
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
|
||||
Reference in New Issue
Block a user