Add an API to register objects to Auto classes (#13989)

* Add API to register a new object in auto classes

* Fix test

* Documentation

* Add to tokenizers and test

* Add cleanup after tests

* Be more careful

* Move import

* Move import

* Cleanup in TF test too

* Add consistency check

* Add documentation

* Style

* Update docs/source/model_doc/auto.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/auto/auto_factory.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2021-10-18 10:22:46 -04:00
committed by GitHub
parent 3d587c5343
commit 2c60ff2fe2
8 changed files with 384 additions and 18 deletions

View File

@@ -17,16 +17,14 @@ import copy
import tempfile
import unittest
from transformers import is_tf_available
from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config, is_tf_available
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
from .test_modeling_bert import BertModelTester
if is_tf_available():
from transformers import (
AutoConfig,
BertConfig,
GPT2Config,
T5Config,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
@@ -34,6 +32,7 @@ if is_tf_available():
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
TFBertForMaskedLM,
TFBertForPreTraining,
@@ -62,6 +61,16 @@ if is_tf_available():
from transformers.models.t5.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
class NewModelConfig(BertConfig):
model_type = "new-model"
if is_tf_available():
class TFNewModel(TFBertModel):
config_class = NewModelConfig
@require_tf
class TFAutoModelTest(unittest.TestCase):
@slow
@@ -224,3 +233,53 @@ class TFAutoModelTest(unittest.TestCase):
for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
def test_new_model_registration(self):
try:
AutoConfig.register("new-model", NewModelConfig)
auto_classes = [
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
]
for auto_class in auto_classes:
with self.subTest(auto_class.__name__):
# Wrong config class will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, TFNewModel)
auto_class.register(NewModelConfig, TFNewModel)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, TFBertModel)
# Now that the config is registered, it can be used as any other config with the auto-API
tiny_config = BertModelTester(self).get_config()
config = NewModelConfig(**tiny_config.to_dict())
model = auto_class.from_config(config)
self.assertIsInstance(model, TFNewModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = auto_class.from_pretrained(tmp_dir)
self.assertIsInstance(new_model, TFNewModel)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
for mapping in (
TF_MODEL_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
):
if NewModelConfig in mapping._extra_content:
del mapping._extra_content[NewModelConfig]