Dynamically load model code from the Hub (#13467)
* Dynamic model * Use defensive flag * Style * Doc and arg rename * Arg rename * Add tests * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -248,6 +248,8 @@ if (
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
||||
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
|
||||
SESSION_ID = uuid4().hex
|
||||
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import copy_func
|
||||
from ...utils import logging
|
||||
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
|
||||
from .dynamic import get_class_from_dynamic_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -122,6 +123,10 @@ FROM_PRETRAINED_TORCH_DOCSTRING = """
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
|
||||
will execute code present on the Hub on your local machine.
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
@@ -211,6 +216,10 @@ FROM_PRETRAINED_TF_DOCSTRING = """
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
|
||||
will execute code present on the Hub on your local machine.
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
@@ -300,6 +309,10 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
trust_remote_code (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to :obj:`True` for repositories you trust and in which you have read the code, as it
|
||||
will execute code present on the Hub on your local machine.
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
@@ -377,13 +390,31 @@ class _BaseAutoModelClass:
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
config = kwargs.pop("config", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
kwargs["_from_auto"] = True
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
if type(config) in cls._model_mapping.keys():
|
||||
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
|
||||
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
|
||||
"the option `trust_remote_code=True` to remove this error."
|
||||
)
|
||||
if kwargs.get("revision", None) is None:
|
||||
logger.warn(
|
||||
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
|
||||
"no malicious code has been contributed in a newer revision."
|
||||
)
|
||||
class_ref = config.auto_map[cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
model_class = get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
|
||||
)
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
elif type(config) in cls._model_mapping.keys():
|
||||
model_class = _get_model_class(config, cls._model_mapping)
|
||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||
raise ValueError(
|
||||
|
||||
231
src/transformers/models/auto/dynamic.py
Normal file
231
src/transformers/models/auto/dynamic.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities to dynamically load model and tokenizer from the Hub."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from ...file_utils import (
|
||||
HF_MODULES_CACHE,
|
||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
)
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def init_hf_modules():
    """
    Make sure the dynamic-modules cache directory is set up as an importable root.

    On the first call, :obj:`HF_MODULES_CACHE` is appended to ``sys.path``, the directory is created if it is missing
    and an ``__init__.py`` file is added to it when none is there yet. Later calls return without doing anything.
    """
    # Presence of the cache folder in `sys.path` is the marker that the setup below already ran.
    if HF_MODULES_CACHE not in sys.path:
        sys.path.append(HF_MODULES_CACHE)
        os.makedirs(HF_MODULES_CACHE, exist_ok=True)
        package_marker = Path(HF_MODULES_CACHE) / "__init__.py"
        if not package_marker.exists():
            package_marker.touch()
|
||||
|
||||
|
||||
def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Ensure the package ``name`` exists inside the cache directory for modules.

    Args:
        name (:obj:`str` or :obj:`os.PathLike`):
            Path of the module to create; it is joined onto :obj:`HF_MODULES_CACHE`.
    """
    init_hf_modules()
    package_dir = Path(HF_MODULES_CACHE) / name
    parent_dir = package_dir.parent
    # A missing parent goes through this same function first, so it also receives its own `__init__.py`.
    if not parent_dir.exists():
        create_dynamic_module(parent_dir)
    os.makedirs(package_dir, exist_ok=True)
    package_marker = package_dir / "__init__.py"
    if not package_marker.exists():
        package_marker.touch()
|
||||
|
||||
|
||||
def check_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.

    Only two textual forms are detected: lines consisting solely of ``import xxx`` and lines starting with
    ``from xxx import``. Relative imports (module names starting with ``.``) are ignored, and only the top-level
    package of each dotted name is tested.

    Args:
        filename (:obj:`str` or :obj:`os.PathLike`):
            Path to the Python source file to inspect. It is read as UTF-8 text.

    Raises:
        :obj:`ImportError`: If one or more of the detected top-level packages cannot be imported. The message lists
        the missing packages in alphabetical order.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # Raw strings are required here: `\s` and `\S` are not valid escape sequences in a regular string literal.
    # Imports of the form `import xxx`
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)

    # Only keep the top-level module, drop relative imports, and de-duplicate. Sorting makes both the order in which
    # the packages are tried and the resulting error message deterministic.
    top_level_packages = sorted({imp.split(".")[0] for imp in imports if not imp.startswith(".")})

    missing_packages = []
    for imp in top_level_packages:
        try:
            importlib.import_module(imp)
        except ImportError:
            missing_packages.append(imp)

    if missing_packages:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )
|
||||
|
||||
|
||||
def get_class_in_module(class_name, module_path):
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (:obj:`str`):
            Name of the attribute to fetch from the imported module.
        module_path (:obj:`str`):
            Path of the module with ``os.path.sep`` as separator and without the ``.py`` extension; separators are
            turned into dots to build the dotted module name.

    Returns:
        The attribute named ``class_name`` of the imported module.

    Raises:
        :obj:`ImportError`: If the module cannot be imported.
        :obj:`AttributeError`: If the module has no attribute named ``class_name``.
    """
    module_path = module_path.replace(os.path.sep, ".")
    # The module file may have been written after the interpreter started (or after its folder was last scanned), so
    # the finders' cached directory listings must be dropped for the import system to notice it.
    importlib.invalidate_caches()
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
|
||||
|
||||
|
||||
def get_class_from_dynamic_module(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    class_name: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    The module file is resolved (from the local folder, or downloaded/cached from the Hub), its imports are checked
    with :func:`check_imports`, it is copied inside the dynamic modules cache (:obj:`HF_MODULES_CACHE`) and finally
    imported from there.

    .. warning::

        Calling this function will execute the code in the module file found locally or downloaded from the Hub. It
        should therefore only be called on trusted repos.

    Args:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            This can be either:

            - a string, the `model id` of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
              namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
            - a path to a `directory` containing the module file, e.g., ``./my_model_directory/``.

        module_file (:obj:`str`):
            The name of the module file containing the class to look for.
        class_name (:obj:`str`):
            The name of the class to import in the module.
        cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
            Path to a directory in which the downloaded module file should be cached if the standard cache should not
            be used.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force to (re-)download the module file and override the cached version if it exists.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git. Only used when ``pretrained_model_name_or_path`` is not a local directory.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, will only try to load the module file from local files. Forced to :obj:`True` in offline
            mode.
        kwargs (additional keyword arguments, `optional`):
            Accepted so that callers can forward extra keyword arguments; they are not used by this function.

    .. note::

        Passing :obj:`use_auth_token=True` is required when you want to use a private model.

    Returns:
        :obj:`type`: The class, dynamically imported from the module.

    Raises:
        :obj:`EnvironmentError`: Re-raised (after logging an error) when the module file cannot be resolved.
        :obj:`ImportError`: Raised by :func:`check_imports` when the module file imports packages that are not
        available in the current environment.

    Examples::

        # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
        # module.
        cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
    """
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
        module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
        # All local folders share the single "local" submodule of the dynamic modules cache.
        submodule = "local"
    else:
        module_file_or_url = hf_bucket_url(
            pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
        )
        # A namespaced model id such as `user/model` becomes the nested folder `user<sep>model`.
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)

    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_path(
            module_file_or_url,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        module_name = module_file
        shutil.copy(resolved_module_file, submodule_path / module_file)
    else:
        # The module file will end up being named module_file + the etag. This way we get the benefit of versioning.
        # NOTE(review): this assumes the cached file name returned by `cached_path` embeds the etag - confirm.
        resolved_module_file_name = Path(resolved_module_file).name
        module_name_parts = [module_file.replace(".py", "")] + resolved_module_file_name.split(".")
        module_name = "_".join(module_name_parts) + ".py"
        # An already-copied version is reused as is.
        if not (submodule_path / module_name).exists():
            shutil.copy(resolved_module_file, submodule_path / module_name)

    # And lastly we get the class inside our newly created module
    final_module = os.path.join(full_submodule, module_name.replace(".py", ""))
    return get_class_in_module(class_name, final_module)
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -28,6 +29,8 @@ from transformers.testing_utils import (
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
@@ -51,6 +54,7 @@ if is_torch_available():
|
||||
FunnelModel,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
PreTrainedModel,
|
||||
RobertaForMaskedLM,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
@@ -75,6 +79,44 @@ if is_torch_available():
|
||||
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
if is_torch_available():

    class FakeModel(PreTrainedModel):
        """Minimal custom model used by the dynamic-loading test; its source is duplicated in FAKE_MODEL_CODE."""

        config_class = BertConfig
        base_model_prefix = "fake"

        def __init__(self, config):
            super().__init__(config)
            # Single square linear layer whose size comes from the config's hidden size.
            self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)

        def forward(self, x):
            return self.linear(x)

        def _init_weights(self, module):
            # Intentionally a no-op.
            pass
|
||||
|
||||
|
||||
# Make sure this is synchronized with the model above.
|
||||
FAKE_MODEL_CODE = """
|
||||
import torch
|
||||
from transformers import BertConfig, PreTrainedModel
|
||||
|
||||
class FakeModel(PreTrainedModel):
|
||||
config_class = BertConfig
|
||||
base_model_prefix = "fake"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
def _init_weights(self, module):
|
||||
pass
|
||||
"""
|
||||
|
||||
|
||||
@require_torch
|
||||
class AutoModelTest(unittest.TestCase):
|
||||
@slow
|
||||
@@ -272,3 +314,19 @@ class AutoModelTest(unittest.TestCase):
|
||||
|
||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||
|
||||
def test_from_pretrained_dynamic_model(self):
    """Save a custom model plus its modeling file to a folder and reload it through ``AutoModel``."""
    bert_config = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
    )
    # Maps `AutoModel` to the class `FakeModel` of the `modeling.py` file written below.
    bert_config.auto_map = {"AutoModel": "modeling.FakeModel"}
    reference_model = FakeModel(bert_config)

    with tempfile.TemporaryDirectory() as save_dir:
        reference_model.save_pretrained(save_dir)
        modeling_path = os.path.join(save_dir, "modeling.py")
        with open(modeling_path, "w") as handle:
            handle.write(FAKE_MODEL_CODE)

        reloaded_model = AutoModel.from_pretrained(save_dir, trust_remote_code=True)
        # The reloaded weights must match the saved ones exactly.
        for expected, actual in zip(reference_model.parameters(), reloaded_model.parameters()):
            self.assertTrue(torch.equal(expected, actual))
|
||||
|
||||
@@ -17,6 +17,7 @@ import copy
|
||||
import gc
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
import random
|
||||
import tempfile
|
||||
@@ -24,7 +25,7 @@ import unittest
|
||||
import warnings
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import HfApi, Repository
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
|
||||
@@ -1792,6 +1793,44 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
|
||||
if is_torch_available():

    class FakeModel(PreTrainedModel):
        """Minimal custom model used by the push-to-hub test; its source is duplicated in FAKE_MODEL_CODE."""

        config_class = BertConfig
        base_model_prefix = "fake"

        def __init__(self, config):
            super().__init__(config)
            # Single square linear layer whose size comes from the config's hidden size.
            self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)

        def forward(self, x):
            return self.linear(x)

        def _init_weights(self, module):
            # Intentionally a no-op.
            pass
|
||||
|
||||
|
||||
# Make sure this is synchronized with the model above.
|
||||
FAKE_MODEL_CODE = """
|
||||
import torch
|
||||
from transformers import BertConfig, PreTrainedModel
|
||||
|
||||
class FakeModel(PreTrainedModel):
|
||||
config_class = BertConfig
|
||||
base_model_prefix = "fake"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
def _init_weights(self, module):
|
||||
pass
|
||||
"""
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
class ModelPushToHubTester(unittest.TestCase):
|
||||
@@ -1812,6 +1851,11 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-dynamic-model")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
@@ -1840,3 +1884,23 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
new_model = BertModel.from_pretrained("valid_org/test-model-org")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_push_to_hub_dynamic_model(self):
    """Push a custom model together with its modeling file to the Hub and reload it through ``AutoModel``."""
    config = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
    )
    # Maps `AutoModel` to the class `FakeModel` of the `modeling.py` file pushed below.
    config.auto_map = {"AutoModel": "modeling.FakeModel"}
    model = FakeModel(config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model", use_auth_token=self._token)
        model.save_pretrained(tmp_dir)
        with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
            f.write(FAKE_MODEL_CODE)

        # The debug listing of the folder that used to follow the push was removed: it only added noise to test logs.
        repo.push_to_hub()

    new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
    # The weights loaded back from the Hub must match the pushed ones exactly.
    for p1, p2 in zip(model.parameters(), new_model.parameters()):
        self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
Reference in New Issue
Block a user