[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)
This commit is contained in:
@@ -190,6 +190,7 @@ class PretrainedConfig(object):
|
|||||||
self.num_labels = kwargs.pop("num_labels", 2)
|
self.num_labels = kwargs.pop("num_labels", 2)
|
||||||
|
|
||||||
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
|
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
|
||||||
|
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
|
||||||
self.prefix = kwargs.pop("prefix", None)
|
self.prefix = kwargs.pop("prefix", None)
|
||||||
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
||||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
|||||||
|
|
||||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||||
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
|
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
|
||||||
|
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
|
||||||
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
|
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -222,6 +222,17 @@ class AutoTokenizer:
|
|||||||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
|
||||||
use_fast = kwargs.pop("use_fast", False)
|
use_fast = kwargs.pop("use_fast", False)
|
||||||
|
|
||||||
|
if config.tokenizer_class is not None:
|
||||||
|
if use_fast and not config.tokenizer_class.endswith("Fast"):
|
||||||
|
tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
|
||||||
|
else:
|
||||||
|
tokenizer_class_candidate = config.tokenizer_class
|
||||||
|
tokenizer_class = globals().get(tokenizer_class_candidate)
|
||||||
|
if tokenizer_class is None:
|
||||||
|
raise ValueError("Tokenizer class {} does not exist or is not currently imported.")
|
||||||
|
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
|
||||||
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
|
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
|
||||||
if isinstance(config, config_class):
|
if isinstance(config, config_class):
|
||||||
if tokenizer_class_fast and use_fast:
|
if tokenizer_class_fast and use_fast:
|
||||||
|
|||||||
@@ -27,7 +27,13 @@ from transformers import (
|
|||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
RobertaTokenizerFast,
|
RobertaTokenizerFast,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER # noqa: F401
|
from transformers.configuration_auto import AutoConfig
|
||||||
|
from transformers.configuration_roberta import RobertaConfig
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||||
|
DUMMY_UNKWOWN_IDENTIFIER,
|
||||||
|
SMALL_MODEL_IDENTIFIER,
|
||||||
|
)
|
||||||
from transformers.tokenization_auto import TOKENIZER_MAPPING
|
from transformers.tokenization_auto import TOKENIZER_MAPPING
|
||||||
|
|
||||||
|
|
||||||
@@ -56,6 +62,14 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
||||||
self.assertEqual(tokenizer.vocab_size, 20)
|
self.assertEqual(tokenizer.vocab_size, 20)
|
||||||
|
|
||||||
|
def test_tokenizer_from_tokenizer_class(self):
|
||||||
|
config = AutoConfig.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER)
|
||||||
|
self.assertIsInstance(config, RobertaConfig)
|
||||||
|
# Check that tokenizer_type ≠ model_type
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER, config=config)
|
||||||
|
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
|
||||||
|
self.assertEqual(tokenizer.vocab_size, 12)
|
||||||
|
|
||||||
def test_tokenizer_identifier_with_correct_config(self):
|
def test_tokenizer_identifier_with_correct_config(self):
|
||||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||||
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
|
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
|
||||||
|
|||||||
Reference in New Issue
Block a user