fix tokenizer_class_from_name for models with - in the name (#13251)
* fix tokenizer_class_from_name * Update src/transformers/models/auto/tokenization_auto.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * add test Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -229,7 +229,10 @@ def tokenizer_class_from_name(class_name: str):
|
||||
if class_name in tokenizers:
|
||||
break
|
||||
|
||||
module = importlib.import_module(f".{module_name}", "transformers.models")
|
||||
if module_name == "openai-gpt":
|
||||
module_name = "openai"
|
||||
|
||||
module = importlib.import_module(f".{module_name.replace('-', '_')}", "transformers.models")
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,11 @@ from transformers import (
|
||||
RobertaTokenizerFast,
|
||||
)
|
||||
from transformers.models.auto.configuration_auto import AutoConfig
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config
|
||||
from transformers.models.auto.tokenization_auto import (
|
||||
TOKENIZER_MAPPING,
|
||||
get_tokenizer_config,
|
||||
tokenizer_class_from_name,
|
||||
)
|
||||
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||
@@ -105,6 +109,24 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
|
||||
self.assertFalse(issubclass(child_config, parent_config))
|
||||
|
||||
def test_model_name_edge_cases_in_mappings(self):
|
||||
# tests: https://github.com/huggingface/transformers/pull/13251
|
||||
# 1. models with `-`, e.g. xlm-roberta -> xlm_roberta
|
||||
# 2. models that don't remap 1-1 from model-name to model file, e.g., openai-gpt -> openai
|
||||
tokenizers = TOKENIZER_MAPPING.values()
|
||||
tokenizer_names = []
|
||||
|
||||
for slow_tok, fast_tok in tokenizers:
|
||||
if slow_tok is not None:
|
||||
tokenizer_names.append(slow_tok.__name__)
|
||||
|
||||
if fast_tok is not None:
|
||||
tokenizer_names.append(fast_tok.__name__)
|
||||
|
||||
for tokenizer_name in tokenizer_names:
|
||||
# must find the right class
|
||||
tokenizer_class_from_name(tokenizer_name)
|
||||
|
||||
@require_tokenizers
|
||||
def test_from_pretrained_use_fast_toggle(self):
|
||||
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
|
||||
|
||||
Reference in New Issue
Block a user