[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)
This commit is contained in:
@@ -27,7 +27,13 @@ from transformers import (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
)
|
||||
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER # noqa: F401
|
||||
from transformers.configuration_auto import AutoConfig
|
||||
from transformers.configuration_roberta import RobertaConfig
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||
DUMMY_UNKWOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
)
|
||||
from transformers.tokenization_auto import TOKENIZER_MAPPING
|
||||
|
||||
|
||||
@@ -56,6 +62,14 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
||||
self.assertEqual(tokenizer.vocab_size, 20)
|
||||
|
||||
def test_tokenizer_from_tokenizer_class(self):
    """Check that AutoTokenizer can return a tokenizer of a different family than the config.

    The config loaded for ``DUMMY_DIFF_TOKENIZER_IDENTIFIER`` must be a
    ``RobertaConfig``, while the tokenizer that ``AutoTokenizer`` builds from
    that same identifier and config must be a Bert tokenizer (slow or fast)
    with a vocabulary of 12 entries.
    """
    checkpoint = DUMMY_DIFF_TOKENIZER_IDENTIFIER
    expected_tokenizer_classes = (BertTokenizer, BertTokenizerFast)

    # The model side of the checkpoint is RoBERTa ...
    roberta_config = AutoConfig.from_pretrained(checkpoint)
    self.assertIsInstance(roberta_config, RobertaConfig)

    # ... but the tokenizer resolved with that config is a Bert one,
    # i.e. tokenizer_type ≠ model_type.
    resolved_tokenizer = AutoTokenizer.from_pretrained(checkpoint, config=roberta_config)
    self.assertIsInstance(resolved_tokenizer, expected_tokenizer_classes)
    self.assertEqual(resolved_tokenizer.vocab_size, 12)
|
||||
|
||||
def test_tokenizer_identifier_with_correct_config(self):
|
||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
|
||||
|
||||
Reference in New Issue
Block a user