Save code of registered custom models (#15379)

* Allow dynamic modules to use relative imports

* Work for configs

* Fix last merge conflict

* Save code of registered custom objects

* Map strings to strings

* Fix test

* Add tokenizer

* Rework tests

* Tests

* Ignore fixtures py files for tests

* Tokenizer test + fix collection

* With full path

* Rework integration

* Fix typo

* Remove changes in conftest

* Test for tokenizers

* Add documentation

* Update docs/source/custom_models.mdx

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Add file structure and file content

* Add more doc

* Style

* Update docs/source/custom_models.mdx

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Address review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sylvain Gugger
2022-02-02 10:44:37 -05:00
committed by GitHub
parent 623d8cb475
commit 44b21f117b
23 changed files with 630 additions and 295 deletions

View File

@@ -15,8 +15,10 @@
import importlib
import os
import sys
import tempfile
import unittest
from pathlib import Path
import transformers.models.auto
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
@@ -25,13 +27,14 @@ from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
class NewModelConfig(BertConfig):
model_type = "new-model"
class AutoConfigTest(unittest.TestCase):
def test_module_spec(self):
self.assertIsNotNone(transformers.models.auto.__spec__)
@@ -65,24 +68,24 @@ class AutoConfigTest(unittest.TestCase):
def test_new_config_registration(self):
try:
AutoConfig.register("new-model", NewModelConfig)
AutoConfig.register("custom", CustomConfig)
# Wrong model type will raise an error
with self.assertRaises(ValueError):
AutoConfig.register("model", NewModelConfig)
AutoConfig.register("model", CustomConfig)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
AutoConfig.register("bert", BertConfig)
# Now that the config is registered, it can be used as any other config with the auto-API
config = NewModelConfig()
config = CustomConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir)
new_config = AutoConfig.from_pretrained(tmp_dir)
self.assertIsInstance(new_config, NewModelConfig)
self.assertIsInstance(new_config, CustomConfig)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
def test_repo_not_found(self):
with self.assertRaisesRegex(

View File

@@ -17,9 +17,11 @@ import copy
import json
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock
from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
@@ -28,6 +30,11 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import PASS, USER, is_staging_test
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
config_common_kwargs = {
"return_dict": False,
"output_hidden_states": True,
@@ -192,23 +199,6 @@ class ConfigTester(object):
self.check_config_arguments_init()
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
@@ -263,20 +253,23 @@ class ConfigPushToHubTester(unittest.TestCase):
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_dynamic_config(self):
config = FakeConfig(attribute=42)
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
config.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
# This has added the proper auto_map field to the config
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
# The code has been copied from fixtures
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_configuration.py")))
repo.push_to_hub()
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "FakeConfig")
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42)

View File

@@ -14,9 +14,10 @@
# limitations under the License.
import copy
import os
import sys
import tempfile
import unittest
from pathlib import Path
from transformers import BertConfig, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
@@ -31,9 +32,15 @@ from transformers.testing_utils import (
from .test_modeling_bert import BertModelTester
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
if is_torch_available():
import torch
from test_module.custom_modeling import CustomModel
from transformers import (
AutoConfig,
AutoModel,
@@ -56,7 +63,6 @@ if is_torch_available():
FunnelModel,
GPT2Config,
GPT2LMHeadModel,
PreTrainedModel,
RobertaForMaskedLM,
T5Config,
T5ForConditionalGeneration,
@@ -81,51 +87,6 @@ if is_torch_available():
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
class NewModelConfig(BertConfig):
model_type = "new-model"
if is_torch_available():
class NewModel(BertModel):
config_class = NewModelConfig
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
# Make sure this is synchronized with the model above.
FAKE_MODEL_CODE = """
import torch
from transformers import BertConfig, PreTrainedModel
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
"""
@require_torch
class AutoModelTest(unittest.TestCase):
@slow
@@ -325,20 +286,25 @@ class AutoModelTest(unittest.TestCase):
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
def test_from_pretrained_dynamic_model_local(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.auto_map = {"AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
try:
AutoConfig.register("custom", CustomConfig)
AutoModel.register(CustomConfig, CustomModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)
config = CustomConfig(hidden_size=32)
model = CustomModel(config)
new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
finally:
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in MODEL_MAPPING._extra_content:
del MODEL_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_model_distant(self):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
@@ -349,7 +315,7 @@ class AutoModelTest(unittest.TestCase):
self.assertEqual(model.__class__.__name__, "NewModel")
def test_new_model_registration(self):
AutoConfig.register("new-model", NewModelConfig)
AutoConfig.register("custom", CustomConfig)
auto_classes = [
AutoModel,
@@ -366,26 +332,27 @@ class AutoModelTest(unittest.TestCase):
with self.subTest(auto_class.__name__):
# Wrong config class will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, NewModel)
auto_class.register(NewModelConfig, NewModel)
auto_class.register(BertConfig, CustomModel)
auto_class.register(CustomConfig, CustomModel)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, BertModel)
# Now that the config is registered, it can be used as any other config with the auto-API
tiny_config = BertModelTester(self).get_config()
config = NewModelConfig(**tiny_config.to_dict())
config = CustomConfig(**tiny_config.to_dict())
model = auto_class.from_config(config)
self.assertIsInstance(model, NewModel)
self.assertIsInstance(model, CustomModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = auto_class.from_pretrained(tmp_dir)
self.assertIsInstance(new_model, NewModel)
# The model is a CustomModel but from the new dynamically imported class.
self.assertIsInstance(new_model, CustomModel)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
for mapping in (
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
@@ -395,8 +362,8 @@ class AutoModelTest(unittest.TestCase):
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
):
if NewModelConfig in mapping._extra_content:
del mapping._extra_content[NewModelConfig]
if CustomConfig in mapping._extra_content:
del mapping._extra_content[CustomConfig]
def test_repo_not_found(self):
with self.assertRaisesRegex(

View File

@@ -20,9 +20,11 @@ import json
import os
import os.path
import random
import sys
import tempfile
import unittest
import warnings
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
@@ -55,10 +57,16 @@ from transformers.testing_utils import (
)
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
if is_torch_available():
import torch
from torch import nn
from test_module.custom_modeling import CustomModel
from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
@@ -2109,61 +2117,6 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.dtype, torch.float16)
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""
if is_torch_available():
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
# Make sure this is synchronized with the model above.
FAKE_MODEL_CODE = """
import torch
from transformers import BertConfig, PreTrainedModel
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
"""
@require_torch
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
@@ -2223,62 +2176,29 @@ class ModelPushToHubTester(unittest.TestCase):
self.assertTrue(torch.equal(p1, p2))
def test_push_to_hub_dynamic_model(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.auto_map = {"AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class()
config = CustomConfig(hidden_size=32)
model = CustomModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model", use_auth_token=self._token)
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)
# checks
self.assertDictEqual(
config.auto_map,
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
)
repo.push_to_hub()
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "CustomModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")
def test_push_to_hub_dynamic_model_and_config(self):
config = FakeConfig(
attribute=42,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
)
config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)
repo.push_to_hub()
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
self.assertEqual(new_model.config.attribute, 42)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")
self.assertEqual(new_model.__class__.__name__, "CustomModel")

View File

@@ -15,8 +15,10 @@
import os
import shutil
import sys
import tempfile
import unittest
from pathlib import Path
import pytest
@@ -30,7 +32,6 @@ from transformers import (
CTRLTokenizer,
GPT2Tokenizer,
GPT2TokenizerFast,
PretrainedConfig,
PreTrainedTokenizerFast,
RobertaTokenizer,
RobertaTokenizerFast,
@@ -52,19 +53,14 @@ from transformers.testing_utils import (
)
class NewConfig(PretrainedConfig):
model_type = "new-model"
sys.path.append(str(Path(__file__).parent.parent / "utils"))
class NewTokenizer(BertTokenizer):
pass
from test_module.custom_configuration import CustomConfig # noqa E402
from test_module.custom_tokenization import CustomTokenizer # noqa E402
if is_tokenizers_available():
class NewTokenizerFast(BertTokenizerFast):
slow_tokenizer_class = NewTokenizer
pass
from test_module.custom_tokenization_fast import CustomTokenizerFast
class AutoTokenizerTest(unittest.TestCase):
@@ -250,41 +246,43 @@ class AutoTokenizerTest(unittest.TestCase):
def test_new_tokenizer_registration(self):
try:
AutoConfig.register("new-model", NewConfig)
AutoConfig.register("custom", CustomConfig)
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer)
tokenizer = NewTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
tokenizer = CustomTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
self.assertIsInstance(new_tokenizer, NewTokenizer)
self.assertIsInstance(new_tokenizer, CustomTokenizer)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if NewConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[NewConfig]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[CustomConfig]
@require_tokenizers
def test_new_tokenizer_fast_registration(self):
try:
AutoConfig.register("new-model", NewConfig)
AutoConfig.register("custom", CustomConfig)
# Can register in two steps
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, None))
AutoTokenizer.register(NewConfig, fast_tokenizer_class=NewTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, None))
AutoTokenizer.register(CustomConfig, fast_tokenizer_class=CustomTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast))
del TOKENIZER_MAPPING._extra_content[NewConfig]
del TOKENIZER_MAPPING._extra_content[CustomConfig]
# Can register in one step
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer, fast_tokenizer_class=NewTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
AutoTokenizer.register(
CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast
)
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast))
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
@@ -295,22 +293,22 @@ class AutoTokenizerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
bert_tokenizer = BertTokenizerFast.from_pretrained(SMALL_MODEL_IDENTIFIER)
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = NewTokenizerFast.from_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
self.assertIsInstance(new_tokenizer, NewTokenizerFast)
self.assertIsInstance(new_tokenizer, CustomTokenizerFast)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False)
self.assertIsInstance(new_tokenizer, NewTokenizer)
self.assertIsInstance(new_tokenizer, CustomTokenizer)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if NewConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[NewConfig]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[CustomConfig]
def test_repo_not_found(self):
with self.assertRaisesRegex(

View File

@@ -21,10 +21,12 @@ import os
import pickle
import re
import shutil
import sys
import tempfile
import unittest
from collections import OrderedDict
from itertools import takewhile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import Repository, delete_repo, login
@@ -67,6 +69,15 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_tokenization import CustomTokenizer # noqa E402
if is_tokenizers_available():
from test_module.custom_tokenization_fast import CustomTokenizerFast
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
SMALL_TRAINING_CORPUS = [
@@ -3690,28 +3701,6 @@ class TokenizerTesterMixin:
self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
class FakeTokenizer(BertTokenizer):
pass
if is_tokenizers_available():
class FakeTokenizerFast(BertTokenizerFast):
pass
# Make sure this is synchronized with the tokenizers above.
FAKE_TOKENIZER_CODE = """
from transformers import BertTokenizer, BertTokenizerFast
class FakeTokenizer(BertTokenizer):
pass
class FakeTokenizerFast(BertTokenizerFast):
pass
"""
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
@@ -3766,47 +3755,62 @@ class TokenizerPushToHubTester(unittest.TestCase):
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self):
CustomTokenizer.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = FakeTokenizer(vocab_file)
tokenizer = CustomTokenizer(vocab_file)
# No fast custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", None)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir((tmp_dir)))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)
with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
tokenizer_config = json.load(f)
self.assertEqual(tokenizer_config["auto_map"], ["custom_tokenization.CustomTokenizer", None])
repo.push_to_hub()
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
# Fast and slow custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")
CustomTokenizerFast.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir((tmp_dir)))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)
with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
tokenizer_config = json.load(f)
self.assertEqual(
tokenizer_config["auto_map"],
["custom_tokenization.CustomTokenizer", "custom_tokenization_fast.CustomTokenizerFast"],
)
repo.push_to_hub()
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizerFast")
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
tokenizer = AutoTokenizer.from_pretrained(
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
class TrieTest(unittest.TestCase):