Save code of registered custom models (#15379)
* Allow dynamic modules to use relative imports
* Work for configs
* Fix last merge conflict
* Save code of registered custom objects
* Map strings to strings
* Fix test
* Add tokenizer
* Rework tests
* Tests
* Ignore fixtures py files for tests
* Tokenizer test + fix collection
* With full path
* Rework integration
* Fix typo
* Remove changes in conftest
* Test for tokenizers
* Add documentation
* Update docs/source/custom_models.mdx
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Add file structure and file content
* Add more doc
* Style
* Update docs/source/custom_models.mdx
  Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Address review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -15,8 +15,10 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -30,7 +32,6 @@ from transformers import (
|
||||
CTRLTokenizer,
|
||||
GPT2Tokenizer,
|
||||
GPT2TokenizerFast,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizerFast,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
@@ -52,19 +53,14 @@ from transformers.testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
class NewConfig(PretrainedConfig):
|
||||
model_type = "new-model"
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
|
||||
|
||||
class NewTokenizer(BertTokenizer):
|
||||
pass
|
||||
from test_module.custom_configuration import CustomConfig # noqa E402
|
||||
from test_module.custom_tokenization import CustomTokenizer # noqa E402
|
||||
|
||||
|
||||
if is_tokenizers_available():
|
||||
|
||||
class NewTokenizerFast(BertTokenizerFast):
|
||||
slow_tokenizer_class = NewTokenizer
|
||||
pass
|
||||
from test_module.custom_tokenization_fast import CustomTokenizerFast
|
||||
|
||||
|
||||
class AutoTokenizerTest(unittest.TestCase):
|
||||
@@ -250,41 +246,43 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
|
||||
def test_new_tokenizer_registration(self):
|
||||
try:
|
||||
AutoConfig.register("new-model", NewConfig)
|
||||
AutoConfig.register("custom", CustomConfig)
|
||||
|
||||
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
|
||||
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
|
||||
# Trying to register something existing in the Transformers library will raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer)
|
||||
|
||||
tokenizer = NewTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
tokenizer = CustomTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer.save_pretrained(tmp_dir)
|
||||
|
||||
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(new_tokenizer, NewTokenizer)
|
||||
self.assertIsInstance(new_tokenizer, CustomTokenizer)
|
||||
|
||||
finally:
|
||||
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["new-model"]
|
||||
if NewConfig in TOKENIZER_MAPPING._extra_content:
|
||||
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||
if "custom" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["custom"]
|
||||
if CustomConfig in TOKENIZER_MAPPING._extra_content:
|
||||
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||
|
||||
@require_tokenizers
|
||||
def test_new_tokenizer_fast_registration(self):
|
||||
try:
|
||||
AutoConfig.register("new-model", NewConfig)
|
||||
AutoConfig.register("custom", CustomConfig)
|
||||
|
||||
# Can register in two steps
|
||||
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
|
||||
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, None))
|
||||
AutoTokenizer.register(NewConfig, fast_tokenizer_class=NewTokenizerFast)
|
||||
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
|
||||
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
|
||||
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, None))
|
||||
AutoTokenizer.register(CustomConfig, fast_tokenizer_class=CustomTokenizerFast)
|
||||
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast))
|
||||
|
||||
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||
# Can register in one step
|
||||
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer, fast_tokenizer_class=NewTokenizerFast)
|
||||
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
|
||||
AutoTokenizer.register(
|
||||
CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast
|
||||
)
|
||||
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast))
|
||||
|
||||
# Trying to register something existing in the Transformers library will raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -295,22 +293,22 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
bert_tokenizer = BertTokenizerFast.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
bert_tokenizer.save_pretrained(tmp_dir)
|
||||
tokenizer = NewTokenizerFast.from_pretrained(tmp_dir)
|
||||
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer.save_pretrained(tmp_dir)
|
||||
|
||||
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(new_tokenizer, NewTokenizerFast)
|
||||
self.assertIsInstance(new_tokenizer, CustomTokenizerFast)
|
||||
|
||||
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False)
|
||||
self.assertIsInstance(new_tokenizer, NewTokenizer)
|
||||
self.assertIsInstance(new_tokenizer, CustomTokenizer)
|
||||
|
||||
finally:
|
||||
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["new-model"]
|
||||
if NewConfig in TOKENIZER_MAPPING._extra_content:
|
||||
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||
if "custom" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["custom"]
|
||||
if CustomConfig in TOKENIZER_MAPPING._extra_content:
|
||||
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
|
||||
Reference in New Issue
Block a user