Add an API to register objects to Auto classes (#13989)
* Add API to register a new object in auto classes * Fix test * Documentation * Add to tokenizers and test * Add cleanup after tests * Be more careful * Move import * Move import * Cleanup in TF test too * Add consistency check * Add documentation * Style * Update docs/source/model_doc/auto.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/models/auto/auto_factory.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -18,7 +18,8 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers import BertConfig, is_torch_available
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
@@ -27,6 +28,8 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_modeling_bert import BertModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -43,7 +46,6 @@ if is_torch_available():
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
BertForPreTraining,
|
||||
BertForQuestionAnswering,
|
||||
@@ -79,8 +81,15 @@ if is_torch_available():
|
||||
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class NewModelConfig(BertConfig):
|
||||
model_type = "new-model"
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
||||
class NewModel(BertModel):
|
||||
config_class = NewModelConfig
|
||||
|
||||
class FakeModel(PreTrainedModel):
|
||||
config_class = BertConfig
|
||||
base_model_prefix = "fake"
|
||||
@@ -330,3 +339,53 @@ class AutoModelTest(unittest.TestCase):
|
||||
new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_new_model_registration(self):
|
||||
AutoConfig.register("new-model", NewModelConfig)
|
||||
|
||||
auto_classes = [
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForPreTraining,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
]
|
||||
|
||||
try:
|
||||
for auto_class in auto_classes:
|
||||
with self.subTest(auto_class.__name__):
|
||||
# Wrong config class will raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
auto_class.register(BertConfig, NewModel)
|
||||
auto_class.register(NewModelConfig, NewModel)
|
||||
# Trying to register something existing in the Transformers library will raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
auto_class.register(BertConfig, BertModel)
|
||||
|
||||
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||
tiny_config = BertModelTester(self).get_config()
|
||||
config = NewModelConfig(**tiny_config.to_dict())
|
||||
model = auto_class.from_config(config)
|
||||
self.assertIsInstance(model, NewModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
new_model = auto_class.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(new_model, NewModel)
|
||||
|
||||
finally:
|
||||
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["new-model"]
|
||||
for mapping in (
|
||||
MODEL_MAPPING,
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
):
|
||||
if NewModelConfig in mapping._extra_content:
|
||||
del mapping._extra_content[NewModelConfig]
|
||||
|
||||
Reference in New Issue
Block a user