Enable code-specific revision for code on the Hub (#23799)
* Enable code-specific revision for code on the Hub * invalidate old revision
This commit is contained in:
@@ -316,7 +316,7 @@ def get_cached_module_file(
|
||||
)
|
||||
new_files.append(f"{module_needed}.py")
|
||||
|
||||
if len(new_files) > 0:
|
||||
if len(new_files) > 0 and revision is None:
|
||||
new_files = "\n".join([f"- {f}" for f in new_files])
|
||||
repo_type_str = "" if repo_type is None else f"{repo_type}s/"
|
||||
url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
|
||||
@@ -340,6 +340,7 @@ def get_class_from_dynamic_module(
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
repo_type: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -391,6 +392,10 @@ def get_class_from_dynamic_module(
|
||||
If `True`, will only try to load the tokenizer configuration from local files.
|
||||
repo_type (`str`, *optional*):
|
||||
Specify the repo type (useful when downloading from a space for instance).
|
||||
code_revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific revision to use for the code on the Hub, if the code leaves in a different repository than the
|
||||
rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
|
||||
storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -415,12 +420,12 @@ def get_class_from_dynamic_module(
|
||||
# Catch the name of the repo if it's specified in `class_reference`
|
||||
if "--" in class_reference:
|
||||
repo_id, class_reference = class_reference.split("--")
|
||||
# Invalidate revision since it's not relevant for this repo
|
||||
revision = "main"
|
||||
else:
|
||||
repo_id = pretrained_model_name_or_path
|
||||
module_file, class_name = class_reference.split(".")
|
||||
|
||||
if code_revision is None and pretrained_model_name_or_path == repo_id:
|
||||
code_revision = revision
|
||||
# And lastly we get the class inside our newly created module
|
||||
final_module = get_cached_module_file(
|
||||
repo_id,
|
||||
@@ -430,7 +435,7 @@ def get_class_from_dynamic_module(
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
revision=code_revision,
|
||||
local_files_only=local_files_only,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user