From 40d60e15367563b1b8061e5960687aad3b67c73a Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 26 Aug 2021 01:29:14 -0700 Subject: [PATCH] fix `tokenizer_class_from_name` for models with `-` in the name (#13251) * fix tokenizer_class_from_name * Update src/transformers/models/auto/tokenization_auto.py Co-authored-by: Lysandre Debut * add test Co-authored-by: Lysandre Debut --- .../models/auto/tokenization_auto.py | 5 +++- tests/test_tokenization_auto.py | 24 ++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index eb56c2dde5..69d6f715ff 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -229,7 +229,10 @@ def tokenizer_class_from_name(class_name: str): if class_name in tokenizers: break - module = importlib.import_module(f".{module_name}", "transformers.models") + if module_name == "openai-gpt": + module_name = "openai" + + module = importlib.import_module(f".{module_name.replace('-', '_')}", "transformers.models") return getattr(module, class_name) diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index f35d0eb5e2..2b0dffd318 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -29,7 +29,11 @@ from transformers import ( RobertaTokenizerFast, ) from transformers.models.auto.configuration_auto import AutoConfig -from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config +from transformers.models.auto.tokenization_auto import ( + TOKENIZER_MAPPING, + get_tokenizer_config, + tokenizer_class_from_name, +) from transformers.models.roberta.configuration_roberta import RobertaConfig from transformers.testing_utils import ( DUMMY_DIFF_TOKENIZER_IDENTIFIER, @@ -105,6 +109,24 @@ class AutoTokenizerTest(unittest.TestCase): with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"): self.assertFalse(issubclass(child_config, parent_config)) + def test_model_name_edge_cases_in_mappings(self): + # tests: https://github.com/huggingface/transformers/pull/13251 + # 1. models with `-`, e.g. xlm-roberta -> xlm_roberta + # 2. models that don't remap 1-1 from model-name to model file, e.g., openai-gpt -> openai + tokenizers = TOKENIZER_MAPPING.values() + tokenizer_names = [] + + for slow_tok, fast_tok in tokenizers: + if slow_tok is not None: + tokenizer_names.append(slow_tok.__name__) + + if fast_tok is not None: + tokenizer_names.append(fast_tok.__name__) + + for tokenizer_name in tokenizer_names: + # must find the right class + tokenizer_class_from_name(tokenizer_name) + @require_tokenizers def test_from_pretrained_use_fast_toggle(self): self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)