Add an API to register objects to Auto classes (#13989)

* Add API to register a new object in auto classes

* Fix test

* Documentation

* Add to tokenizers and test

* Add cleanup after tests

* Be more careful

* Move import

* Move import

* Cleanup in TF test too

* Add consistency check

* Add documentation

* Style

* Update docs/source/model_doc/auto.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/auto/auto_factory.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2021-10-18 10:22:46 -04:00
committed by GitHub
parent 3d587c5343
commit 2c60ff2fe2
8 changed files with 384 additions and 18 deletions

View File

@@ -18,7 +18,8 @@ import os
import tempfile
import unittest
from transformers import is_torch_available
from transformers import BertConfig, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
@@ -27,6 +28,8 @@ from transformers.testing_utils import (
slow,
)
from .test_modeling_bert import BertModelTester
if is_torch_available():
import torch
@@ -43,7 +46,6 @@ if is_torch_available():
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
BertConfig,
BertForMaskedLM,
BertForPreTraining,
BertForQuestionAnswering,
@@ -79,8 +81,15 @@ if is_torch_available():
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
class NewModelConfig(BertConfig):
model_type = "new-model"
if is_torch_available():
class NewModel(BertModel):
config_class = NewModelConfig
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
@@ -330,3 +339,53 @@ class AutoModelTest(unittest.TestCase):
new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_new_model_registration(self):
AutoConfig.register("new-model", NewModelConfig)
auto_classes = [
AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
]
try:
for auto_class in auto_classes:
with self.subTest(auto_class.__name__):
# Wrong config class will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, NewModel)
auto_class.register(NewModelConfig, NewModel)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, BertModel)
# Now that the config is registered, it can be used as any other config with the auto-API
tiny_config = BertModelTester(self).get_config()
config = NewModelConfig(**tiny_config.to_dict())
model = auto_class.from_config(config)
self.assertIsInstance(model, NewModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = auto_class.from_pretrained(tmp_dir)
self.assertIsInstance(new_model, NewModel)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
for mapping in (
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
):
if NewModelConfig in mapping._extra_content:
del mapping._extra_content[NewModelConfig]