Add an API to register objects to Auto classes (#13989)
* Add API to register a new object in auto classes * Fix test * Documentation * Add to tokenizers and test * Add cleanup after tests * Be more careful * Move import * Move import * Cleanup in TF test too * Add consistency check * Add documentation * Style * Update docs/source/model_doc/auto.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/models/auto/auto_factory.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -27,7 +27,32 @@ Instantiating one of :class:`~transformers.AutoConfig`, :class:`~transformers.Au
|
|||||||
|
|
||||||
will create a model that is an instance of :class:`~transformers.BertModel`.
|
will create a model that is an instance of :class:`~transformers.BertModel`.
|
||||||
|
|
||||||
There is one class of :obj:`AutoModel` for each task, and for each backend (PyTorch or TensorFlow).
|
There is one class of :obj:`AutoModel` for each task, and for each backend (PyTorch, TensorFlow, or Flax).
|
||||||
|
|
||||||
|
Extending the Auto Classes
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Each of the auto classes has a method to be extended with your custom classes. For instance, if you have defined a
|
||||||
|
custom class of model :obj:`NewModel`, make sure you have a :obj:`NewModelConfig`; then you can add those to the auto
|
||||||
|
classes like this:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoModel
|
||||||
|
|
||||||
|
AutoConfig.register("new-model", NewModelConfig)
|
||||||
|
AutoModel.register(NewModelConfig, NewModel)
|
||||||
|
|
||||||
|
You will then be able to use the auto classes like you would usually do!
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
If your :obj:`NewModelConfig` is a subclass of :class:`~transformers.PretrainedConfig`, make sure its
|
||||||
|
:obj:`model_type` attribute is set to the same key you use when registering the config (here :obj:`"new-model"`).
|
||||||
|
|
||||||
|
Likewise, if your :obj:`NewModel` is a subclass of :class:`~transformers.PreTrainedModel`, make sure its
|
||||||
|
:obj:`config_class` attribute is set to the same class you use when registering the model (here
|
||||||
|
:obj:`NewModelConfig`).
|
||||||
|
|
||||||
|
|
||||||
AutoConfig
|
AutoConfig
|
||||||
|
|||||||
@@ -422,6 +422,25 @@ class _BaseAutoModelClass:
|
|||||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def register(cls, config_class, model_class):
    """
    Register a new model for this class.

    Args:
        config_class (:class:`~transformers.PretrainedConfig`):
            The configuration corresponding to the model to register.
        model_class (:class:`~transformers.PreTrainedModel`):
            The model to register.

    Raises:
        ValueError: If ``model_class`` exposes a ``config_class`` attribute that differs from the
            ``config_class`` passed. Errors raised by the underlying mapping's ``register`` are propagated.
    """
    # Only enforce consistency when the model class actually declares which config it expects.
    if hasattr(model_class, "config_class") and model_class.config_class != config_class:
        raise ValueError(
            "The model class you are passing has a `config_class` attribute that is not consistent with the "
            f"config class you passed (model has {model_class.config_class} and you passed {config_class}). Fix "
            "one of those so they match!"
        )
    cls._model_mapping.register(config_class, model_class)
|
||||||
|
|
||||||
|
|
||||||
def insert_head_doc(docstring, head_doc=""):
|
def insert_head_doc(docstring, head_doc=""):
|
||||||
if len(head_doc) > 0:
|
if len(head_doc) > 0:
|
||||||
@@ -507,9 +526,12 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
self._config_mapping = config_mapping
|
self._config_mapping = config_mapping
|
||||||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
||||||
self._model_mapping = model_mapping
|
self._model_mapping = model_mapping
|
||||||
|
self._extra_content = {}
|
||||||
self._modules = {}
|
self._modules = {}
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
if key in self._extra_content:
|
||||||
|
return self._extra_content[key]
|
||||||
model_type = self._reverse_config_mapping[key.__name__]
|
model_type = self._reverse_config_mapping[key.__name__]
|
||||||
if model_type not in self._model_mapping:
|
if model_type not in self._model_mapping:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
@@ -523,11 +545,12 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
return getattribute_from_module(self._modules[module_name], attr)
|
return getattribute_from_module(self._modules[module_name], attr)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return [
|
mapping_keys = [
|
||||||
self._load_attr_from_module(key, name)
|
self._load_attr_from_module(key, name)
|
||||||
for key, name in self._config_mapping.items()
|
for key, name in self._config_mapping.items()
|
||||||
if key in self._model_mapping.keys()
|
if key in self._model_mapping.keys()
|
||||||
]
|
]
|
||||||
|
return mapping_keys + list(self._extra_content.keys())
|
||||||
|
|
||||||
def get(self, key, default):
|
def get(self, key, default):
|
||||||
try:
|
try:
|
||||||
@@ -539,14 +562,15 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
return bool(self.keys())
|
return bool(self.keys())
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return [
|
mapping_values = [
|
||||||
self._load_attr_from_module(key, name)
|
self._load_attr_from_module(key, name)
|
||||||
for key, name in self._model_mapping.items()
|
for key, name in self._model_mapping.items()
|
||||||
if key in self._config_mapping.keys()
|
if key in self._config_mapping.keys()
|
||||||
]
|
]
|
||||||
|
return mapping_values + list(self._extra_content.values())
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return [
|
mapping_items = [
|
||||||
(
|
(
|
||||||
self._load_attr_from_module(key, self._config_mapping[key]),
|
self._load_attr_from_module(key, self._config_mapping[key]),
|
||||||
self._load_attr_from_module(key, self._model_mapping[key]),
|
self._load_attr_from_module(key, self._model_mapping[key]),
|
||||||
@@ -554,12 +578,26 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
for key in self._model_mapping.keys()
|
for key in self._model_mapping.keys()
|
||||||
if key in self._config_mapping.keys()
|
if key in self._config_mapping.keys()
|
||||||
]
|
]
|
||||||
|
return mapping_items + list(self._extra_content.items())
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self._model_mapping.keys())
|
return iter(self.keys())
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item):
|
||||||
|
if item in self._extra_content:
|
||||||
|
return True
|
||||||
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
|
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
|
||||||
return False
|
return False
|
||||||
model_type = self._reverse_config_mapping[item.__name__]
|
model_type = self._reverse_config_mapping[item.__name__]
|
||||||
return model_type in self._model_mapping
|
return model_type in self._model_mapping
|
||||||
|
|
||||||
|
def register(self, key, value):
    """
    Register a new model in this mapping.

    Args:
        key: The object used as lookup key (its ``__name__``, when present, is checked against the reverse
            config mapping).
        value: The object to associate with ``key``.

    Raises:
        ValueError: If ``key.__name__`` resolves, through the reverse config mapping, to a model type that is
            already present in the built-in model mapping.
    """
    if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
        model_type = self._reverse_config_mapping[key.__name__]
        # Direct membership test, consistent with `__contains__`.
        if model_type in self._model_mapping:
            raise ValueError(f"'{key}' is already used by a Transformers model.")

    # Plain assignment: registering the same key again replaces the previous extra entry.
    self._extra_content[key] = value
|
||||||
|
|||||||
@@ -282,9 +282,12 @@ class _LazyConfigMapping(OrderedDict):
|
|||||||
|
|
||||||
def __init__(self, mapping):
|
def __init__(self, mapping):
|
||||||
self._mapping = mapping
|
self._mapping = mapping
|
||||||
|
self._extra_content = {}
|
||||||
self._modules = {}
|
self._modules = {}
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
if key in self._extra_content:
|
||||||
|
return self._extra_content[key]
|
||||||
if key not in self._mapping:
|
if key not in self._mapping:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
value = self._mapping[key]
|
value = self._mapping[key]
|
||||||
@@ -294,19 +297,27 @@ class _LazyConfigMapping(OrderedDict):
|
|||||||
return getattr(self._modules[module_name], value)
|
return getattr(self._modules[module_name], value)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self._mapping.keys()
|
return list(self._mapping.keys()) + list(self._extra_content.keys())
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return [self[k] for k in self._mapping.keys()]
|
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return [(k, self[k]) for k in self._mapping.keys()]
|
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self._mapping.keys())
|
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item):
|
||||||
return item in self._mapping
|
return item in self._mapping or item in self._extra_content
|
||||||
|
|
||||||
|
def register(self, key, value):
    """
    Register a new configuration in this mapping.

    Args:
        key: The name under which the configuration is registered.
        value: The configuration object to associate with ``key``.

    Raises:
        ValueError: If ``key`` is already present in the built-in mapping.
    """
    # Direct membership test, consistent with `__contains__`.
    if key in self._mapping:
        raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
    # Plain assignment: registering the same key again replaces the previous extra entry.
    self._extra_content[key] = value
|
||||||
|
|
||||||
|
|
||||||
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||||
@@ -550,3 +561,20 @@ class AutoConfig:
|
|||||||
f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
|
f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
|
||||||
f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
|
f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
def register(model_type, config):
    """
    Register a new configuration for this class.

    Args:
        model_type (:obj:`str`): The model type like "bert" or "gpt".
        config (:class:`~transformers.PretrainedConfig`): The config to register.

    Raises:
        ValueError: If ``config`` is a subclass of :class:`~transformers.PretrainedConfig` whose
            ``model_type`` attribute differs from the ``model_type`` passed. Errors raised by
            ``CONFIG_MAPPING.register`` are propagated.
    """
    # The consistency check only applies to real config subclasses, which carry a `model_type` attribute.
    if issubclass(config, PretrainedConfig) and config.model_type != model_type:
        raise ValueError(
            "The config you are passing has a `model_type` attribute that is not consistent with the model type "
            f"you passed (config has {config.model_type} and you passed {model_type}). Fix one of those so they "
            "match!"
        )
    CONFIG_MAPPING.register(model_type, config)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from ...file_utils import (
|
|||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
)
|
)
|
||||||
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -237,6 +238,11 @@ def tokenizer_class_from_name(class_name: str):
|
|||||||
module = importlib.import_module(f".{module_name}", "transformers.models")
|
module = importlib.import_module(f".{module_name}", "transformers.models")
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
|
for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
if getattr(tokenizer, "__name__", None) == class_name:
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -510,3 +516,46 @@ class AutoTokenizer:
|
|||||||
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
|
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
|
||||||
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None):
    """
    Register a new tokenizer in this mapping.

    Args:
        config_class (:class:`~transformers.PretrainedConfig`):
            The configuration corresponding to the model to register.
        slow_tokenizer_class (:class:`~transformers.PreTrainedTokenizer`, `optional`):
            The slow tokenizer to register.
        fast_tokenizer_class (:class:`~transformers.PreTrainedTokenizerFast`, `optional`):
            The fast tokenizer to register.

    Raises:
        ValueError: If neither tokenizer class is passed, if a fast tokenizer is passed as
            ``slow_tokenizer_class`` or a slow tokenizer as ``fast_tokenizer_class``, or if the fast
            tokenizer's ``slow_tokenizer_class`` attribute differs from the slow tokenizer class passed.
    """
    if slow_tokenizer_class is None and fast_tokenizer_class is None:
        raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class`")
    if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
        raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
    if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
        raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")

    # When both classes are given, the fast tokenizer must point back at the same slow tokenizer.
    if (
        slow_tokenizer_class is not None
        and fast_tokenizer_class is not None
        and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
        and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
    ):
        raise ValueError(
            "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
            "consistent with the slow tokenizer class you passed (fast tokenizer has "
            f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}). Fix one of those "
            "so they match!"
        )

    # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
    if config_class in TOKENIZER_MAPPING._extra_content:
        existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
        if slow_tokenizer_class is None:
            slow_tokenizer_class = existing_slow
        if fast_tokenizer_class is None:
            fast_tokenizer_class = existing_fast

    TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class))
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||||
@@ -25,6 +26,10 @@ from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
|||||||
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
||||||
|
|
||||||
|
|
||||||
|
class NewModelConfig(BertConfig):
|
||||||
|
model_type = "new-model"
|
||||||
|
|
||||||
|
|
||||||
class AutoConfigTest(unittest.TestCase):
|
class AutoConfigTest(unittest.TestCase):
|
||||||
def test_config_from_model_shortcut(self):
|
def test_config_from_model_shortcut(self):
|
||||||
config = AutoConfig.from_pretrained("bert-base-uncased")
|
config = AutoConfig.from_pretrained("bert-base-uncased")
|
||||||
@@ -51,3 +56,24 @@ class AutoConfigTest(unittest.TestCase):
|
|||||||
keys = list(CONFIG_MAPPING.keys())
|
keys = list(CONFIG_MAPPING.keys())
|
||||||
for i, key in enumerate(keys):
|
for i, key in enumerate(keys):
|
||||||
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
|
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
|
||||||
|
|
||||||
|
def test_new_config_registration(self):
|
||||||
|
try:
|
||||||
|
AutoConfig.register("new-model", NewModelConfig)
|
||||||
|
# Wrong model type will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
AutoConfig.register("model", NewModelConfig)
|
||||||
|
# Trying to register something existing in the Transformers library will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
AutoConfig.register("bert", BertConfig)
|
||||||
|
|
||||||
|
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||||
|
config = NewModelConfig()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config.save_pretrained(tmp_dir)
|
||||||
|
new_config = AutoConfig.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(new_config, NewModelConfig)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||||
|
del CONFIG_MAPPING._extra_content["new-model"]
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import BertConfig, is_torch_available
|
||||||
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
DUMMY_UNKNOWN_IDENTIFIER,
|
DUMMY_UNKNOWN_IDENTIFIER,
|
||||||
SMALL_MODEL_IDENTIFIER,
|
SMALL_MODEL_IDENTIFIER,
|
||||||
@@ -27,6 +28,8 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .test_modeling_bert import BertModelTester
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -43,7 +46,6 @@ if is_torch_available():
|
|||||||
AutoModelForTableQuestionAnswering,
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
BertConfig,
|
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
BertForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
@@ -79,8 +81,15 @@ if is_torch_available():
|
|||||||
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
|
class NewModelConfig(BertConfig):
|
||||||
|
model_type = "new-model"
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|
||||||
|
class NewModel(BertModel):
|
||||||
|
config_class = NewModelConfig
|
||||||
|
|
||||||
class FakeModel(PreTrainedModel):
|
class FakeModel(PreTrainedModel):
|
||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
base_model_prefix = "fake"
|
base_model_prefix = "fake"
|
||||||
@@ -330,3 +339,53 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
|
new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
|
||||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
def test_new_model_registration(self):
|
||||||
|
AutoConfig.register("new-model", NewModelConfig)
|
||||||
|
|
||||||
|
auto_classes = [
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForMaskedLM,
|
||||||
|
AutoModelForPreTraining,
|
||||||
|
AutoModelForQuestionAnswering,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForTokenClassification,
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
for auto_class in auto_classes:
|
||||||
|
with self.subTest(auto_class.__name__):
|
||||||
|
# Wrong config class will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
auto_class.register(BertConfig, NewModel)
|
||||||
|
auto_class.register(NewModelConfig, NewModel)
|
||||||
|
# Trying to register something existing in the Transformers library will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
auto_class.register(BertConfig, BertModel)
|
||||||
|
|
||||||
|
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||||
|
tiny_config = BertModelTester(self).get_config()
|
||||||
|
config = NewModelConfig(**tiny_config.to_dict())
|
||||||
|
model = auto_class.from_config(config)
|
||||||
|
self.assertIsInstance(model, NewModel)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
new_model = auto_class.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(new_model, NewModel)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||||
|
del CONFIG_MAPPING._extra_content["new-model"]
|
||||||
|
for mapping in (
|
||||||
|
MODEL_MAPPING,
|
||||||
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
):
|
||||||
|
if NewModelConfig in mapping._extra_content:
|
||||||
|
del mapping._extra_content[NewModelConfig]
|
||||||
|
|||||||
@@ -17,16 +17,14 @@ import copy
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config, is_tf_available
|
||||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
|
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
|
||||||
|
|
||||||
|
from .test_modeling_bert import BertModelTester
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
|
||||||
BertConfig,
|
|
||||||
GPT2Config,
|
|
||||||
T5Config,
|
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFAutoModelForCausalLM,
|
TFAutoModelForCausalLM,
|
||||||
TFAutoModelForMaskedLM,
|
TFAutoModelForMaskedLM,
|
||||||
@@ -34,6 +32,7 @@ if is_tf_available():
|
|||||||
TFAutoModelForQuestionAnswering,
|
TFAutoModelForQuestionAnswering,
|
||||||
TFAutoModelForSeq2SeqLM,
|
TFAutoModelForSeq2SeqLM,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
|
TFAutoModelForTokenClassification,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
@@ -62,6 +61,16 @@ if is_tf_available():
|
|||||||
from transformers.models.t5.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.t5.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
|
class NewModelConfig(BertConfig):
|
||||||
|
model_type = "new-model"
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
|
||||||
|
class TFNewModel(TFBertModel):
|
||||||
|
config_class = NewModelConfig
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFAutoModelTest(unittest.TestCase):
|
class TFAutoModelTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
@@ -224,3 +233,53 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
|
|
||||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||||
|
|
||||||
|
def test_new_model_registration(self):
|
||||||
|
try:
|
||||||
|
AutoConfig.register("new-model", NewModelConfig)
|
||||||
|
|
||||||
|
auto_classes = [
|
||||||
|
TFAutoModel,
|
||||||
|
TFAutoModelForCausalLM,
|
||||||
|
TFAutoModelForMaskedLM,
|
||||||
|
TFAutoModelForPreTraining,
|
||||||
|
TFAutoModelForQuestionAnswering,
|
||||||
|
TFAutoModelForSequenceClassification,
|
||||||
|
TFAutoModelForTokenClassification,
|
||||||
|
]
|
||||||
|
|
||||||
|
for auto_class in auto_classes:
|
||||||
|
with self.subTest(auto_class.__name__):
|
||||||
|
# Wrong config class will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
auto_class.register(BertConfig, TFNewModel)
|
||||||
|
auto_class.register(NewModelConfig, TFNewModel)
|
||||||
|
# Trying to register something existing in the Transformers library will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
auto_class.register(BertConfig, TFBertModel)
|
||||||
|
|
||||||
|
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||||
|
tiny_config = BertModelTester(self).get_config()
|
||||||
|
config = NewModelConfig(**tiny_config.to_dict())
|
||||||
|
model = auto_class.from_config(config)
|
||||||
|
self.assertIsInstance(model, TFNewModel)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
new_model = auto_class.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(new_model, TFNewModel)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||||
|
del CONFIG_MAPPING._extra_content["new-model"]
|
||||||
|
for mapping in (
|
||||||
|
TF_MODEL_MAPPING,
|
||||||
|
TF_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
):
|
||||||
|
if NewModelConfig in mapping._extra_content:
|
||||||
|
del mapping._extra_content[NewModelConfig]
|
||||||
|
|||||||
@@ -24,16 +24,19 @@ from transformers import (
|
|||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
BertConfig,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
BertTokenizerFast,
|
BertTokenizerFast,
|
||||||
CTRLTokenizer,
|
CTRLTokenizer,
|
||||||
GPT2Tokenizer,
|
GPT2Tokenizer,
|
||||||
GPT2TokenizerFast,
|
GPT2TokenizerFast,
|
||||||
|
PretrainedConfig,
|
||||||
PreTrainedTokenizerFast,
|
PreTrainedTokenizerFast,
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
RobertaTokenizerFast,
|
RobertaTokenizerFast,
|
||||||
|
is_tokenizers_available,
|
||||||
)
|
)
|
||||||
from transformers.models.auto.configuration_auto import AutoConfig
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||||
from transformers.models.auto.tokenization_auto import (
|
from transformers.models.auto.tokenization_auto import (
|
||||||
TOKENIZER_MAPPING,
|
TOKENIZER_MAPPING,
|
||||||
get_tokenizer_config,
|
get_tokenizer_config,
|
||||||
@@ -49,6 +52,21 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NewConfig(PretrainedConfig):
|
||||||
|
model_type = "new-model"
|
||||||
|
|
||||||
|
|
||||||
|
class NewTokenizer(BertTokenizer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if is_tokenizers_available():
|
||||||
|
|
||||||
|
class NewTokenizerFast(BertTokenizerFast):
|
||||||
|
slow_tokenizer_class = NewTokenizer
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AutoTokenizerTest(unittest.TestCase):
|
class AutoTokenizerTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_tokenizer_from_pretrained(self):
|
def test_tokenizer_from_pretrained(self):
|
||||||
@@ -225,3 +243,67 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
self.assertEqual(config["tokenizer_class"], "BertTokenizer")
|
self.assertEqual(config["tokenizer_class"], "BertTokenizer")
|
||||||
# Check other keys just to make sure the config was properly saved /reloaded.
|
# Check other keys just to make sure the config was properly saved /reloaded.
|
||||||
self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)
|
self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)
|
||||||
|
|
||||||
|
def test_new_tokenizer_registration(self):
|
||||||
|
try:
|
||||||
|
AutoConfig.register("new-model", NewConfig)
|
||||||
|
|
||||||
|
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
|
||||||
|
# Trying to register something existing in the Transformers library will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer)
|
||||||
|
|
||||||
|
tokenizer = NewTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tokenizer.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(new_tokenizer, NewTokenizer)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||||
|
del CONFIG_MAPPING._extra_content["new-model"]
|
||||||
|
if NewConfig in TOKENIZER_MAPPING._extra_content:
|
||||||
|
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_new_tokenizer_fast_registration(self):
|
||||||
|
try:
|
||||||
|
AutoConfig.register("new-model", NewConfig)
|
||||||
|
|
||||||
|
# Can register in two steps
|
||||||
|
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
|
||||||
|
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, None))
|
||||||
|
AutoTokenizer.register(NewConfig, fast_tokenizer_class=NewTokenizerFast)
|
||||||
|
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
|
||||||
|
|
||||||
|
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||||
|
# Can register in one step
|
||||||
|
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer, fast_tokenizer_class=NewTokenizerFast)
|
||||||
|
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
|
||||||
|
|
||||||
|
# Trying to register something existing in the Transformers library will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
AutoTokenizer.register(BertConfig, fast_tokenizer_class=BertTokenizerFast)
|
||||||
|
|
||||||
|
# We pass through a bert tokenizer fast because there is no converter slow to fast for our new tokenizer
|
||||||
|
# and that model does not have a tokenizer.json
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
bert_tokenizer = BertTokenizerFast.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||||
|
bert_tokenizer.save_pretrained(tmp_dir)
|
||||||
|
tokenizer = NewTokenizerFast.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tokenizer.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(new_tokenizer, NewTokenizerFast)
|
||||||
|
|
||||||
|
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False)
|
||||||
|
self.assertIsInstance(new_tokenizer, NewTokenizer)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||||
|
del CONFIG_MAPPING._extra_content["new-model"]
|
||||||
|
if NewConfig in TOKENIZER_MAPPING._extra_content:
|
||||||
|
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||||
|
|||||||
Reference in New Issue
Block a user