Add an API to register objects to Auto classes (#13989)
* Add API to register a new object in auto classes
* Fix test
* Documentation
* Add to tokenizers and test
* Add cleanup after tests
* Be more careful
* Move import
* Move import
* Cleanup in TF test too
* Add consistency check
* Add documentation
* Style
* Update docs/source/model_doc/auto.rst
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Update src/transformers/models/auto/auto_factory.py
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -24,16 +24,19 @@ from transformers import (
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
AutoTokenizer,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
BertTokenizerFast,
|
||||
CTRLTokenizer,
|
||||
GPT2Tokenizer,
|
||||
GPT2TokenizerFast,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizerFast,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from transformers.models.auto.configuration_auto import AutoConfig
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
from transformers.models.auto.tokenization_auto import (
|
||||
TOKENIZER_MAPPING,
|
||||
get_tokenizer_config,
|
||||
@@ -49,6 +52,21 @@ from transformers.testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
class NewConfig(PretrainedConfig):
    """Bare-bones configuration class used by the registration tests below.

    It only sets ``model_type``; the tests register it with ``AutoConfig``
    under that same ``"new-model"`` key.
    """

    model_type = "new-model"
|
||||
|
||||
|
||||
class NewTokenizer(BertTokenizer):
    """Slow tokenizer stand-in for the registration tests.

    Adds nothing on top of ``BertTokenizer``; it exists only so the tests have
    a distinct class to register with ``AutoTokenizer``.
    """
|
||||
|
||||
|
||||
# The fast tokenizer class is only defined when `is_tokenizers_available()`
# reports that the fast-tokenizer backend can be used.
if is_tokenizers_available():

    class NewTokenizerFast(BertTokenizerFast):
        """Fast tokenizer stand-in for the registration tests.

        Adds nothing on top of ``BertTokenizerFast`` except pointing
        ``slow_tokenizer_class`` at ``NewTokenizer``.
        """

        slow_tokenizer_class = NewTokenizer
|
||||
|
||||
|
||||
class AutoTokenizerTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
@@ -225,3 +243,67 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(config["tokenizer_class"], "BertTokenizer")
|
||||
# Check other keys just to make sure the config was properly saved /reloaded.
|
||||
self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)
|
||||
|
||||
def test_new_tokenizer_registration(self):
    """Register a custom slow tokenizer and reload it through ``AutoTokenizer``.

    The registrations mutate module-level mappings, so they are undone in the
    ``finally`` clause whether or not the assertions pass.
    """
    try:
        AutoConfig.register("new-model", NewConfig)
        AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)

        # A config that already ships with the library cannot be registered again.
        with self.assertRaises(ValueError):
            AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer)

        slow_tok = NewTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
        with tempfile.TemporaryDirectory() as save_dir:
            slow_tok.save_pretrained(save_dir)

            # Reload while the directory still exists; the auto class must
            # hand back the registered custom class.
            reloaded = AutoTokenizer.from_pretrained(save_dir)
            self.assertIsInstance(reloaded, NewTokenizer)

    finally:
        # Drop whatever this test registered so other tests see clean mappings.
        CONFIG_MAPPING._extra_content.pop("new-model", None)
        TOKENIZER_MAPPING._extra_content.pop(NewConfig, None)
|
||||
|
||||
@require_tokenizers
def test_new_tokenizer_fast_registration(self):
    """Register slow and fast custom tokenizers and reload both via ``AutoTokenizer``.

    Covers registering the pair in two calls and in a single call, the error
    raised for a config that already exists in the library, and a
    save/reload round trip with and without ``use_fast=False``. All
    registrations are undone in the ``finally`` clause.
    """
    try:
        AutoConfig.register("new-model", NewConfig)

        # Two-step registration: slow class first, fast class afterwards.
        AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
        self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, None))
        AutoTokenizer.register(NewConfig, fast_tokenizer_class=NewTokenizerFast)
        self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))

        # Reset the entry, then register both classes in a single call.
        del TOKENIZER_MAPPING._extra_content[NewConfig]
        AutoTokenizer.register(
            NewConfig,
            slow_tokenizer_class=NewTokenizer,
            fast_tokenizer_class=NewTokenizerFast,
        )
        self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))

        # A config that already ships with the library cannot be registered again.
        with self.assertRaises(ValueError):
            AutoTokenizer.register(BertConfig, fast_tokenizer_class=BertTokenizerFast)

        # Go through a fast BERT tokenizer because there is no slow-to-fast
        # converter for the new tokenizer and that model does not have a
        # tokenizer.json.
        with tempfile.TemporaryDirectory() as bert_dir:
            bert_fast = BertTokenizerFast.from_pretrained(SMALL_MODEL_IDENTIFIER)
            bert_fast.save_pretrained(bert_dir)
            fast_tok = NewTokenizerFast.from_pretrained(bert_dir)

        with tempfile.TemporaryDirectory() as save_dir:
            fast_tok.save_pretrained(save_dir)

            # Default loading must give the fast class ...
            reloaded_fast = AutoTokenizer.from_pretrained(save_dir)
            self.assertIsInstance(reloaded_fast, NewTokenizerFast)

            # ... and use_fast=False must give the slow class.
            reloaded_slow = AutoTokenizer.from_pretrained(save_dir, use_fast=False)
            self.assertIsInstance(reloaded_slow, NewTokenizer)

    finally:
        # Drop whatever this test registered so other tests see clean mappings.
        CONFIG_MAPPING._extra_content.pop("new-model", None)
        TOKENIZER_MAPPING._extra_content.pop(NewConfig, None)
|
||||
|
||||
Reference in New Issue
Block a user