Add an API to register objects to Auto classes (#13989)
* Add API to register a new object in auto classes * Fix test * Documentation * Add to tokenizers and test * Add cleanup after tests * Be more careful * Move import * Move import * Cleanup in TF test too * Add consistency check * Add documentation * Style * Update docs/source/model_doc/auto.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/models/auto/auto_factory.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -17,16 +17,14 @@ import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config, is_tf_available
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
|
||||
|
||||
from .test_modeling_bert import BertModelTester
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
BertConfig,
|
||||
GPT2Config,
|
||||
T5Config,
|
||||
TFAutoModel,
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForMaskedLM,
|
||||
@@ -34,6 +32,7 @@ if is_tf_available():
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelWithLMHead,
|
||||
TFBertForMaskedLM,
|
||||
TFBertForPreTraining,
|
||||
@@ -62,6 +61,16 @@ if is_tf_available():
|
||||
from transformers.models.t5.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class NewModelConfig(BertConfig):
|
||||
model_type = "new-model"
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
|
||||
class TFNewModel(TFBertModel):
|
||||
config_class = NewModelConfig
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFAutoModelTest(unittest.TestCase):
|
||||
@slow
|
||||
@@ -224,3 +233,53 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
|
||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||
|
||||
def test_new_model_registration(self):
|
||||
try:
|
||||
AutoConfig.register("new-model", NewModelConfig)
|
||||
|
||||
auto_classes = [
|
||||
TFAutoModel,
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForMaskedLM,
|
||||
TFAutoModelForPreTraining,
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForTokenClassification,
|
||||
]
|
||||
|
||||
for auto_class in auto_classes:
|
||||
with self.subTest(auto_class.__name__):
|
||||
# Wrong config class will raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
auto_class.register(BertConfig, TFNewModel)
|
||||
auto_class.register(NewModelConfig, TFNewModel)
|
||||
# Trying to register something existing in the Transformers library will raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
auto_class.register(BertConfig, TFBertModel)
|
||||
|
||||
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||
tiny_config = BertModelTester(self).get_config()
|
||||
config = NewModelConfig(**tiny_config.to_dict())
|
||||
model = auto_class.from_config(config)
|
||||
self.assertIsInstance(model, TFNewModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
new_model = auto_class.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(new_model, TFNewModel)
|
||||
|
||||
finally:
|
||||
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["new-model"]
|
||||
for mapping in (
|
||||
TF_MODEL_MAPPING,
|
||||
TF_MODEL_FOR_PRETRAINING_MAPPING,
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
):
|
||||
if NewModelConfig in mapping._extra_content:
|
||||
del mapping._extra_content[NewModelConfig]
|
||||
|
||||
Reference in New Issue
Block a user