Revert "Use code on the Hub from another repo" (#22813)
Revert "Use code on the Hub from another repo (#22698)"
This reverts commit ea7b0a539a.
This commit is contained in:
@@ -667,11 +667,6 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
||||||
|
|
||||||
if "auto_map" in config_dict and not is_local:
|
|
||||||
config_dict["auto_map"] = {
|
|
||||||
k: (f"{pretrained_model_name_or_path}--{v}" if "--" not in v else v)
|
|
||||||
for k, v in config_dict["auto_map"].items()
|
|
||||||
}
|
|
||||||
return config_dict, kwargs
|
return config_dict, kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from .utils import (
|
|||||||
extract_commit_hash,
|
extract_commit_hash,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
logging,
|
logging,
|
||||||
try_to_load_from_cache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -223,16 +222,11 @@ def get_cached_module_file(
|
|||||||
|
|
||||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if is_local:
|
|
||||||
submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
|
submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
|
||||||
else:
|
else:
|
||||||
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
||||||
cached_module = try_to_load_from_cache(
|
|
||||||
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
|
|
||||||
)
|
|
||||||
|
|
||||||
new_files = []
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_module_file = cached_file(
|
resolved_module_file = cached_file(
|
||||||
@@ -247,8 +241,6 @@ def get_cached_module_file(
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
_commit_hash=_commit_hash,
|
_commit_hash=_commit_hash,
|
||||||
)
|
)
|
||||||
if not is_local and cached_module != resolved_module_file:
|
|
||||||
new_files.append(module_file)
|
|
||||||
|
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||||
@@ -292,7 +284,7 @@ def get_cached_module_file(
|
|||||||
importlib.invalidate_caches()
|
importlib.invalidate_caches()
|
||||||
# Make sure we also have every file with relative
|
# Make sure we also have every file with relative
|
||||||
for module_needed in modules_needed:
|
for module_needed in modules_needed:
|
||||||
if not (submodule_path / f"{module_needed}.py").exists():
|
if not (submodule_path / module_needed).exists():
|
||||||
get_cached_module_file(
|
get_cached_module_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
f"{module_needed}.py",
|
f"{module_needed}.py",
|
||||||
@@ -303,24 +295,14 @@ def get_cached_module_file(
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
_commit_hash=commit_hash,
|
|
||||||
)
|
)
|
||||||
new_files.append(f"{module_needed}.py")
|
|
||||||
|
|
||||||
if len(new_files) > 0:
|
|
||||||
new_files = "\n".join([f"- {f}" for f in new_files])
|
|
||||||
logger.warning(
|
|
||||||
f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
|
|
||||||
"\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
|
|
||||||
"versions of the code file, you can pin a revision."
|
|
||||||
)
|
|
||||||
|
|
||||||
return os.path.join(full_submodule, module_file)
|
return os.path.join(full_submodule, module_file)
|
||||||
|
|
||||||
|
|
||||||
def get_class_from_dynamic_module(
|
def get_class_from_dynamic_module(
|
||||||
class_reference: str,
|
|
||||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||||
|
module_file: str,
|
||||||
|
class_name: str,
|
||||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
resume_download: bool = False,
|
resume_download: bool = False,
|
||||||
@@ -341,8 +323,6 @@ def get_class_from_dynamic_module(
|
|||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
class_reference (`str`):
|
|
||||||
The full name of the class to load, including its module and optionally its repo.
|
|
||||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||||
This can be either:
|
This can be either:
|
||||||
|
|
||||||
@@ -352,7 +332,6 @@ def get_class_from_dynamic_module(
|
|||||||
- a path to a *directory* containing a configuration file saved using the
|
- a path to a *directory* containing a configuration file saved using the
|
||||||
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
||||||
|
|
||||||
This is used when `class_reference` does not specify another repo.
|
|
||||||
module_file (`str`):
|
module_file (`str`):
|
||||||
The name of the module file containing the class to look for.
|
The name of the module file containing the class to look for.
|
||||||
class_name (`str`):
|
class_name (`str`):
|
||||||
@@ -392,25 +371,12 @@ def get_class_from_dynamic_module(
|
|||||||
```python
|
```python
|
||||||
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
|
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
|
||||||
# module.
|
# module.
|
||||||
cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
|
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
|
||||||
|
|
||||||
# Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
|
|
||||||
# module.
|
|
||||||
cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
|
|
||||||
```"""
|
```"""
|
||||||
# Catch the name of the repo if it's specified in `class_reference`
|
|
||||||
if "--" in class_reference:
|
|
||||||
repo_id, class_reference = class_reference.split("--")
|
|
||||||
# Invalidate revision since it's not relevant for this repo
|
|
||||||
revision = "main"
|
|
||||||
else:
|
|
||||||
repo_id = pretrained_model_name_or_path
|
|
||||||
module_file, class_name = class_reference.split(".")
|
|
||||||
|
|
||||||
# And lastly we get the class inside our newly created module
|
# And lastly we get the class inside our newly created module
|
||||||
final_module = get_cached_module_file(
|
final_module = get_cached_module_file(
|
||||||
repo_id,
|
pretrained_model_name_or_path,
|
||||||
module_file + ".py",
|
module_file,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
|
|||||||
@@ -403,12 +403,8 @@ class _BaseAutoModelClass:
|
|||||||
"no malicious code has been contributed in a newer revision."
|
"no malicious code has been contributed in a newer revision."
|
||||||
)
|
)
|
||||||
class_ref = config.auto_map[cls.__name__]
|
class_ref = config.auto_map[cls.__name__]
|
||||||
if "--" in class_ref:
|
|
||||||
repo_id, class_ref = class_ref.split("--")
|
|
||||||
else:
|
|
||||||
repo_id = config.name_or_path
|
|
||||||
module_file, class_name = class_ref.split(".")
|
module_file, class_name = class_ref.split(".")
|
||||||
model_class = get_class_from_dynamic_module(repo_id, module_file + ".py", class_name, **kwargs)
|
model_class = get_class_from_dynamic_module(config.name_or_path, module_file + ".py", class_name, **kwargs)
|
||||||
return model_class._from_config(config, **kwargs)
|
return model_class._from_config(config, **kwargs)
|
||||||
elif type(config) in cls._model_mapping.keys():
|
elif type(config) in cls._model_mapping.keys():
|
||||||
model_class = _get_model_class(config, cls._model_mapping)
|
model_class = _get_model_class(config, cls._model_mapping)
|
||||||
@@ -456,10 +452,17 @@ class _BaseAutoModelClass:
|
|||||||
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
|
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
|
||||||
"the option `trust_remote_code=True` to remove this error."
|
"the option `trust_remote_code=True` to remove this error."
|
||||||
)
|
)
|
||||||
|
if hub_kwargs.get("revision", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
|
||||||
|
"no malicious code has been contributed in a newer revision."
|
||||||
|
)
|
||||||
class_ref = config.auto_map[cls.__name__]
|
class_ref = config.auto_map[cls.__name__]
|
||||||
|
module_file, class_name = class_ref.split(".")
|
||||||
model_class = get_class_from_dynamic_module(
|
model_class = get_class_from_dynamic_module(
|
||||||
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
|
pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs
|
||||||
)
|
)
|
||||||
|
model_class.register_for_auto_class(cls.__name__)
|
||||||
return model_class.from_pretrained(
|
return model_class.from_pretrained(
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
|
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -921,8 +921,17 @@ class AutoConfig:
|
|||||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||||
" set the option `trust_remote_code=True` to remove this error."
|
" set the option `trust_remote_code=True` to remove this error."
|
||||||
)
|
)
|
||||||
|
if kwargs.get("revision", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
|
||||||
|
"ensure no malicious code has been contributed in a newer revision."
|
||||||
|
)
|
||||||
class_ref = config_dict["auto_map"]["AutoConfig"]
|
class_ref = config_dict["auto_map"]["AutoConfig"]
|
||||||
config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
|
module_file, class_name = class_ref.split(".")
|
||||||
|
config_class = get_class_from_dynamic_module(
|
||||||
|
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||||
|
)
|
||||||
|
config_class.register_for_auto_class()
|
||||||
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif "model_type" in config_dict:
|
elif "model_type" in config_dict:
|
||||||
config_class = CONFIG_MAPPING[config_dict["model_type"]]
|
config_class = CONFIG_MAPPING[config_dict["model_type"]]
|
||||||
|
|||||||
@@ -333,9 +333,17 @@ class AutoFeatureExtractor:
|
|||||||
"in that repo on your local machine. Make sure you have read the code there to avoid "
|
"in that repo on your local machine. Make sure you have read the code there to avoid "
|
||||||
"malicious use, then set the option `trust_remote_code=True` to remove this error."
|
"malicious use, then set the option `trust_remote_code=True` to remove this error."
|
||||||
)
|
)
|
||||||
|
if kwargs.get("revision", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
|
||||||
|
"code to ensure no malicious code has been contributed in a newer revision."
|
||||||
|
)
|
||||||
|
|
||||||
|
module_file, class_name = feature_extractor_auto_map.split(".")
|
||||||
feature_extractor_class = get_class_from_dynamic_module(
|
feature_extractor_class = get_class_from_dynamic_module(
|
||||||
feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
|
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||||
)
|
)
|
||||||
|
feature_extractor_class.register_for_auto_class()
|
||||||
else:
|
else:
|
||||||
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
|
feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
|
||||||
|
|
||||||
|
|||||||
@@ -355,9 +355,17 @@ class AutoImageProcessor:
|
|||||||
"in that repo on your local machine. Make sure you have read the code there to avoid "
|
"in that repo on your local machine. Make sure you have read the code there to avoid "
|
||||||
"malicious use, then set the option `trust_remote_code=True` to remove this error."
|
"malicious use, then set the option `trust_remote_code=True` to remove this error."
|
||||||
)
|
)
|
||||||
|
if kwargs.get("revision", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"Explicitly passing a `revision` is encouraged when loading a image processor with custom "
|
||||||
|
"code to ensure no malicious code has been contributed in a newer revision."
|
||||||
|
)
|
||||||
|
|
||||||
|
module_file, class_name = image_processor_auto_map.split(".")
|
||||||
image_processor_class = get_class_from_dynamic_module(
|
image_processor_class = get_class_from_dynamic_module(
|
||||||
image_processor_auto_map, pretrained_model_name_or_path, **kwargs
|
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||||
)
|
)
|
||||||
|
image_processor_class.register_for_auto_class()
|
||||||
else:
|
else:
|
||||||
image_processor_class = image_processor_class_from_name(image_processor_class)
|
image_processor_class = image_processor_class_from_name(image_processor_class)
|
||||||
|
|
||||||
|
|||||||
@@ -254,10 +254,17 @@ class AutoProcessor:
|
|||||||
"in that repo on your local machine. Make sure you have read the code there to avoid "
|
"in that repo on your local machine. Make sure you have read the code there to avoid "
|
||||||
"malicious use, then set the option `trust_remote_code=True` to remove this error."
|
"malicious use, then set the option `trust_remote_code=True` to remove this error."
|
||||||
)
|
)
|
||||||
|
if kwargs.get("revision", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
|
||||||
|
"code to ensure no malicious code has been contributed in a newer revision."
|
||||||
|
)
|
||||||
|
|
||||||
|
module_file, class_name = processor_auto_map.split(".")
|
||||||
processor_class = get_class_from_dynamic_module(
|
processor_class = get_class_from_dynamic_module(
|
||||||
processor_auto_map, pretrained_model_name_or_path, **kwargs
|
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||||
)
|
)
|
||||||
|
processor_class.register_for_auto_class()
|
||||||
else:
|
else:
|
||||||
processor_class = processor_class_from_name(processor_class)
|
processor_class = processor_class_from_name(processor_class)
|
||||||
|
|
||||||
|
|||||||
@@ -671,12 +671,22 @@ class AutoTokenizer:
|
|||||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use,"
|
" repo on your local machine. Make sure you have read the code there to avoid malicious use,"
|
||||||
" then set the option `trust_remote_code=True` to remove this error."
|
" then set the option `trust_remote_code=True` to remove this error."
|
||||||
)
|
)
|
||||||
|
if kwargs.get("revision", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure"
|
||||||
|
" no malicious code has been contributed in a newer revision."
|
||||||
|
)
|
||||||
|
|
||||||
if use_fast and tokenizer_auto_map[1] is not None:
|
if use_fast and tokenizer_auto_map[1] is not None:
|
||||||
class_ref = tokenizer_auto_map[1]
|
class_ref = tokenizer_auto_map[1]
|
||||||
else:
|
else:
|
||||||
class_ref = tokenizer_auto_map[0]
|
class_ref = tokenizer_auto_map[0]
|
||||||
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
module_file, class_name = class_ref.split(".")
|
||||||
|
tokenizer_class = get_class_from_dynamic_module(
|
||||||
|
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||||
|
)
|
||||||
|
tokenizer_class.register_for_auto_class()
|
||||||
|
|
||||||
elif use_fast and not config_tokenizer_class.endswith("Fast"):
|
elif use_fast and not config_tokenizer_class.endswith("Fast"):
|
||||||
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
|
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
|
||||||
|
|||||||
@@ -727,8 +727,9 @@ def pipeline(
|
|||||||
" set the option `trust_remote_code=True` to remove this error."
|
" set the option `trust_remote_code=True` to remove this error."
|
||||||
)
|
)
|
||||||
class_ref = targeted_task["impl"]
|
class_ref = targeted_task["impl"]
|
||||||
|
module_file, class_name = class_ref.split(".")
|
||||||
pipeline_class = get_class_from_dynamic_module(
|
pipeline_class = get_class_from_dynamic_module(
|
||||||
class_ref, model, revision=revision, use_auth_token=use_auth_token
|
model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
normalized_task, targeted_task, task_options = check_task(task)
|
normalized_task, targeted_task, task_options = check_task(task)
|
||||||
|
|||||||
@@ -1817,7 +1817,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
_commit_hash=commit_hash,
|
_commit_hash=commit_hash,
|
||||||
_is_local=is_local,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1832,7 +1831,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
cache_dir=None,
|
cache_dir=None,
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
_commit_hash=None,
|
_commit_hash=None,
|
||||||
_is_local=False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
|
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
|
||||||
@@ -1863,6 +1861,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
|
# First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
|
||||||
config_tokenizer_class = init_kwargs.get("tokenizer_class")
|
config_tokenizer_class = init_kwargs.get("tokenizer_class")
|
||||||
init_kwargs.pop("tokenizer_class", None)
|
init_kwargs.pop("tokenizer_class", None)
|
||||||
|
init_kwargs.pop("auto_map", None)
|
||||||
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
||||||
if not init_inputs:
|
if not init_inputs:
|
||||||
init_inputs = saved_init_inputs
|
init_inputs = saved_init_inputs
|
||||||
@@ -1870,15 +1869,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
config_tokenizer_class = None
|
config_tokenizer_class = None
|
||||||
init_kwargs = init_configuration
|
init_kwargs = init_configuration
|
||||||
|
|
||||||
if "auto_map" in init_kwargs and not _is_local:
|
|
||||||
new_auto_map = {}
|
|
||||||
for key, value in init_kwargs["auto_map"].items():
|
|
||||||
if isinstance(value, (list, tuple)):
|
|
||||||
new_auto_map[key] = [f"{pretrained_model_name_or_path}--{v}" for v in value]
|
|
||||||
else:
|
|
||||||
new_auto_map[key] = f"{pretrained_model_name_or_path}--{value}"
|
|
||||||
init_kwargs["auto_map"] = new_auto_map
|
|
||||||
|
|
||||||
if config_tokenizer_class is None:
|
if config_tokenizer_class is None:
|
||||||
from .models.auto.configuration_auto import AutoConfig # tests_ignore
|
from .models.auto.configuration_auto import AutoConfig # tests_ignore
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,6 @@ from .hub import (
|
|||||||
is_remote_url,
|
is_remote_url,
|
||||||
move_cache,
|
move_cache,
|
||||||
send_example_telemetry,
|
send_example_telemetry,
|
||||||
try_to_load_from_cache,
|
|
||||||
)
|
)
|
||||||
from .import_utils import (
|
from .import_utils import (
|
||||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||||
|
|||||||
@@ -298,34 +298,6 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
|
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
def test_from_pretrained_dynamic_model_distant_with_ref(self):
|
|
||||||
model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True)
|
|
||||||
self.assertEqual(model.__class__.__name__, "NewModel")
|
|
||||||
|
|
||||||
# Test model can be reloaded.
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(tmp_dir)
|
|
||||||
reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
|
|
||||||
|
|
||||||
self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
|
|
||||||
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
|
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
|
||||||
|
|
||||||
# This one uses a relative import to a util file, this checks it is downloaded and used properly.
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
"hf-internal-testing/ref_to_test_dynamic_model_with_util", trust_remote_code=True
|
|
||||||
)
|
|
||||||
self.assertEqual(model.__class__.__name__, "NewModel")
|
|
||||||
|
|
||||||
# Test model can be reloaded.
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(tmp_dir)
|
|
||||||
reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
|
|
||||||
|
|
||||||
self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
|
|
||||||
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
|
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
|
||||||
|
|
||||||
def test_new_model_registration(self):
|
def test_new_model_registration(self):
|
||||||
AutoConfig.register("custom", CustomConfig)
|
AutoConfig.register("custom", CustomConfig)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user