From c4ecd234f250cc6272812505b4a9b6cd1e29a816 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 30 Aug 2021 11:55:18 -0400
Subject: [PATCH] Fix AutoTokenizer when no fast tokenizer is available (#13336)

* Fix AutoTokenizer when a tokenizer has no fast version

* Add test
---
 src/transformers/models/auto/tokenization_auto.py | 8 ++++----
 tests/test_tokenization_auto.py                   | 6 ++++++
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 9d6b8e330e..7542781f68 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -229,12 +229,12 @@ def tokenizer_class_from_name(class_name: str):
 
     for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
         if class_name in tokenizers:
-            break
+            module_name = model_type_to_module_name(module_name)
 
-    module_name = model_type_to_module_name(module_name)
+            module = importlib.import_module(f".{module_name}", "transformers.models")
+            return getattr(module, class_name)
 
-    module = importlib.import_module(f".{module_name}", "transformers.models")
-    return getattr(module, class_name)
+    return None
 
 
 def get_tokenizer_config(
diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py
index 2b0dffd318..250323a5f9 100644
--- a/tests/test_tokenization_auto.py
+++ b/tests/test_tokenization_auto.py
@@ -22,6 +22,7 @@ from transformers import (
     AutoTokenizer,
     BertTokenizer,
     BertTokenizerFast,
+    CTRLTokenizer,
     GPT2Tokenizer,
     GPT2TokenizerFast,
     PreTrainedTokenizerFast,
@@ -162,6 +163,11 @@ class AutoTokenizerTest(unittest.TestCase):
         self.assertIsInstance(tokenizer2, tokenizer.__class__)
         self.assertEqual(tokenizer2.vocab_size, 12)
 
+    def test_auto_tokenizer_fast_no_slow(self):
+        tokenizer = AutoTokenizer.from_pretrained("ctrl")
+        # There is no fast CTRL so this always gives us a slow tokenizer.
+        self.assertIsInstance(tokenizer, CTRLTokenizer)
+
     def test_get_tokenizer_config(self):
         # Check we can load the tokenizer config of an online model.
         config = get_tokenizer_config("bert-base-cased")