diff --git a/tests/test_configuration_auto.py b/tests/test_configuration_auto.py index a6f3376d75..b1ac22ed6c 100644 --- a/tests/test_configuration_auto.py +++ b/tests/test_configuration_auto.py @@ -16,7 +16,7 @@ import os import unittest -from transformers.configuration_auto import AutoConfig +from transformers.configuration_auto import CONFIG_MAPPING, AutoConfig from transformers.configuration_bert import BertConfig from transformers.configuration_roberta import RobertaConfig @@ -42,3 +42,13 @@ class AutoConfigTest(unittest.TestCase): def test_config_for_model_str(self): config = AutoConfig.for_model("roberta") self.assertIsInstance(config, RobertaConfig) + + def test_pattern_matching_fallback(self): + """ + In cases where config.json doesn't include a model_type, + perform a few safety checks on the config mapping's order. + """ + # no key string should be included in a later key string (typical failure case) + keys = list(CONFIG_MAPPING.keys()) + for i, key in enumerate(keys): + self.assertFalse(any(key in later_key for later_key in keys[i+1:]))