Save code of registered custom models (#15379)

* Allow dynamic modules to use relative imports

* Work for configs

* Fix last merge conflict

* Save code of registered custom objects

* Map strings to strings

* Fix test

* Add tokenizer

* Rework tests

* Tests

* Ignore fixtures py files for tests

* Tokenizer test + fix collection

* With full path

* Rework integration

* Fix typo

* Remove changes in conftest

* Test for tokenizers

* Add documentation

* Update docs/source/custom_models.mdx

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Add file structure and file content

* Add more doc

* Style

* Update docs/source/custom_models.mdx

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Address review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sylvain Gugger
2022-02-02 10:44:37 -05:00
committed by GitHub
parent 623d8cb475
commit 44b21f117b
23 changed files with 630 additions and 295 deletions

View File

@@ -17,9 +17,11 @@ import copy
import json
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock
from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
@@ -28,6 +30,11 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import PASS, USER, is_staging_test
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
config_common_kwargs = {
"return_dict": False,
"output_hidden_states": True,
@@ -192,23 +199,6 @@ class ConfigTester(object):
self.check_config_arguments_init()
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
@@ -263,20 +253,23 @@ class ConfigPushToHubTester(unittest.TestCase):
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_dynamic_config(self):
config = FakeConfig(attribute=42)
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
config.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
# This has added the proper auto_map field to the config
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
# The code has been copied from fixtures
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_configuration.py")))
repo.push_to_hub()
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
# Can't make an isinstance check because new_config is an instance of a class loaded from a dynamic module
self.assertEqual(new_config.__class__.__name__, "FakeConfig")
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42)