Automatically sort auto mappings (#17250)
* Automatically sort auto mappings * Better class extraction * Some auto class magic * Adapt test and underlying behavior * Remove re-used config * Quality
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase):
|
||||
self.assertIsInstance(config, RobertaConfig)
|
||||
|
||||
def test_pattern_matching_fallback(self):
|
||||
"""
|
||||
In cases where config.json doesn't include a model_type,
|
||||
perform a few safety checks on the config mapping's order.
|
||||
"""
|
||||
# no key string should be included in a later key string (typical failure case)
|
||||
keys = list(CONFIG_MAPPING.keys())
|
||||
for i, key in enumerate(keys):
|
||||
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# This model name contains bert and roberta, but roberta ends up being picked.
|
||||
folder = os.path.join(tmp_dir, "fake-roberta")
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
with open(os.path.join(folder, "config.json"), "w") as f:
|
||||
f.write(json.dumps({}))
|
||||
config = AutoConfig.from_pretrained(folder)
|
||||
self.assertEqual(type(config), RobertaConfig)
|
||||
|
||||
def test_new_config_registration(self):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user