Automatically sort auto mappings (#17250)

* Automatically sort auto mappings

* Better class extraction

* Some auto class magic

* Adapt test and underlying behavior

* Remove re-used config

* Quality
This commit is contained in:
Sylvain Gugger
2022-05-16 13:24:20 -04:00
committed by GitHub
parent 2f611f85e2
commit ddb1a47ec8
14 changed files with 1199 additions and 1094 deletions

View File

@@ -14,6 +14,8 @@
# limitations under the License.
import importlib
import json
import os
import sys
import tempfile
import unittest
@@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase):
self.assertIsInstance(config, RobertaConfig)
def test_pattern_matching_fallback(self):
"""
In cases where config.json doesn't include a model_type,
perform a few safety checks on the config mapping's order.
"""
# no key string should be included in a later key string (typical failure case)
keys = list(CONFIG_MAPPING.keys())
for i, key in enumerate(keys):
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
with tempfile.TemporaryDirectory() as tmp_dir:
# This model name contains bert and roberta, but roberta ends up being picked.
folder = os.path.join(tmp_dir, "fake-roberta")
os.makedirs(folder, exist_ok=True)
with open(os.path.join(folder, "config.json"), "w") as f:
f.write(json.dumps({}))
config = AutoConfig.from_pretrained(folder)
self.assertEqual(type(config), RobertaConfig)
def test_new_config_registration(self):
try: