Add an API to register objects to Auto classes (#13989)

* Add API to register a new object in auto classes

* Fix test

* Documentation

* Add to tokenizers and test

* Add cleanup after tests

* Be more careful

* Move import

* Move import

* Cleanup in TF test too

* Add consistency check

* Add documentation

* Style

* Update docs/source/model_doc/auto.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/auto/auto_factory.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2021-10-18 10:22:46 -04:00
committed by GitHub
parent 3d587c5343
commit 2c60ff2fe2
8 changed files with 384 additions and 18 deletions

View File

@@ -24,16 +24,19 @@ from transformers import (
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
AutoTokenizer,
BertConfig,
BertTokenizer,
BertTokenizerFast,
CTRLTokenizer,
GPT2Tokenizer,
GPT2TokenizerFast,
PretrainedConfig,
PreTrainedTokenizerFast,
RobertaTokenizer,
RobertaTokenizerFast,
is_tokenizers_available,
)
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.models.auto.tokenization_auto import (
TOKENIZER_MAPPING,
get_tokenizer_config,
@@ -49,6 +52,21 @@ from transformers.testing_utils import (
)
class NewConfig(PretrainedConfig):
model_type = "new-model"
class NewTokenizer(BertTokenizer):
pass
if is_tokenizers_available():
class NewTokenizerFast(BertTokenizerFast):
slow_tokenizer_class = NewTokenizer
pass
class AutoTokenizerTest(unittest.TestCase):
@slow
def test_tokenizer_from_pretrained(self):
@@ -225,3 +243,67 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertEqual(config["tokenizer_class"], "BertTokenizer")
# Check other keys just to make sure the config was properly saved /reloaded.
self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)
def test_new_tokenizer_registration(self):
try:
AutoConfig.register("new-model", NewConfig)
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer)
tokenizer = NewTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
self.assertIsInstance(new_tokenizer, NewTokenizer)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if NewConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[NewConfig]
@require_tokenizers
def test_new_tokenizer_fast_registration(self):
try:
AutoConfig.register("new-model", NewConfig)
# Can register in two steps
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, None))
AutoTokenizer.register(NewConfig, fast_tokenizer_class=NewTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
del TOKENIZER_MAPPING._extra_content[NewConfig]
# Can register in one step
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer, fast_tokenizer_class=NewTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
AutoTokenizer.register(BertConfig, fast_tokenizer_class=BertTokenizerFast)
# We pass through a bert tokenizer fast cause there is no converter slow to fast for our new toknizer
# and that model does not have a tokenizer.json
with tempfile.TemporaryDirectory() as tmp_dir:
bert_tokenizer = BertTokenizerFast.from_pretrained(SMALL_MODEL_IDENTIFIER)
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = NewTokenizerFast.from_pretrained(tmp_dir)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
self.assertIsInstance(new_tokenizer, NewTokenizerFast)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False)
self.assertIsInstance(new_tokenizer, NewTokenizer)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if NewConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[NewConfig]