[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)

This commit is contained in:
Julien Chaumond
2020-09-09 10:22:59 +02:00
committed by GitHub
parent 03e363f9ae
commit ed71c21d6a
4 changed files with 28 additions and 1 deletions

View File

@@ -27,7 +27,13 @@ from transformers import (
RobertaTokenizer,
RobertaTokenizerFast,
)
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER # noqa: F401
from transformers.configuration_auto import AutoConfig
from transformers.configuration_roberta import RobertaConfig
from transformers.testing_utils import (
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
DUMMY_UNKWOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
)
from transformers.tokenization_auto import TOKENIZER_MAPPING
@@ -56,6 +62,14 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
self.assertEqual(tokenizer.vocab_size, 20)
def test_tokenizer_from_tokenizer_class(self):
config = AutoConfig.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER)
self.assertIsInstance(config, RobertaConfig)
# Check that tokenizer_type ≠ model_type
tokenizer = AutoTokenizer.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER, config=config)
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
self.assertEqual(tokenizer.vocab_size, 12)
def test_tokenizer_identifier_with_correct_config(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")