[tests] Safety checks on CONFIG_MAPPING
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.configuration_auto import AutoConfig
|
||||
from transformers.configuration_auto import CONFIG_MAPPING, AutoConfig
|
||||
from transformers.configuration_bert import BertConfig
|
||||
from transformers.configuration_roberta import RobertaConfig
|
||||
|
||||
@@ -42,3 +42,13 @@ class AutoConfigTest(unittest.TestCase):
|
||||
def test_config_for_model_str(self):
|
||||
config = AutoConfig.for_model("roberta")
|
||||
self.assertIsInstance(config, RobertaConfig)
|
||||
|
||||
def test_pattern_matching_fallback(self):
|
||||
"""
|
||||
In cases where config.json doesn't include a model_type,
|
||||
perform a few safety checks on the config mapping's order.
|
||||
"""
|
||||
# no key string should be included in a later key string (typical failure case)
|
||||
keys = list(CONFIG_MAPPING.keys())
|
||||
for i, key in enumerate(keys):
|
||||
self.assertFalse(any(key in later_key for later_key in keys[i+1:]))
|
||||
|
||||
Reference in New Issue
Block a user