Load sub-configs from composite configs (#34410)

* save/load sub-configs

* nit forgot these

* fix copies

* move test to common

* use dict for sub-configs

* add load-save-laod test

* clean up modeling check

* oops this are correct keys

* fix some tests, missed some composite configs

* this model was missed
This commit is contained in:
Raushan Turganbay
2024-11-05 11:34:01 +01:00
committed by GitHub
parent 5e1fd4e204
commit 893ad04fad
78 changed files with 464 additions and 1052 deletions

View File

@@ -17,12 +17,17 @@ import copy
import json
import os
import tempfile
from pathlib import Path
from transformers import is_torch_available
from transformers.utils import direct_transformers_import
from .utils.test_configuration_utils import config_common_kwargs
transformers_module = direct_transformers_import(Path(__file__).parent)
class ConfigTester:
def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
self.parent = parent
@@ -35,9 +40,10 @@ class ConfigTester:
config = self.config_class(**self.inputs_dict)
common_properties = (
["hidden_size", "num_attention_heads", "num_hidden_layers"]
if self.common_properties is None
if self.common_properties is None and not self.config_class.sub_configs
else self.common_properties
)
common_properties = [] if common_properties is None else common_properties
# Add common fields for text models
if self.has_text_modality:
@@ -110,6 +116,44 @@ class ConfigTester:
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def create_and_test_config_from_and_save_pretrained_composite(self):
"""
Tests that composite or nested cofigs can be loaded and saved correctly. In case the config
has a sub-config, we should be able to call `sub_config.from_pretrained('general_config_file')`
and get a result same as if we loaded the whole config and obtained `config.sub_config` from it.
"""
config = self.config_class(**self.inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_pretrained(tmpdirname)
general_config_loaded = self.config_class.from_pretrained(tmpdirname)
general_config_dict = config.to_dict()
# Iterate over all sub_configs if there are any and load them with their own classes
sub_configs = self.config_class.sub_configs
for sub_config_key, sub_class in sub_configs.items():
if sub_class.__name__ == "AutoConfig":
sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
else:
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
# Pop `transformers_version`, it never exists when a config is part of a general composite config
# Verify that loading with subconfig class results in same dict as if we loaded with general composite config class
sub_config_loaded_dict = sub_config_loaded.to_dict()
sub_config_loaded_dict.pop("transformers_version", None)
self.parent.assertEqual(sub_config_loaded_dict, general_config_dict[sub_config_key])
# Verify that the loaded config type is same as in the general config
type_from_general_config = type(getattr(general_config_loaded, sub_config_key))
self.parent.assertTrue(isinstance(sub_config_loaded, type_from_general_config))
# Now save only the sub-config and load it back to make sure the whole load-save-load pipeline works
with tempfile.TemporaryDirectory() as tmpdirname2:
sub_config_loaded.save_pretrained(tmpdirname2)
sub_config_loaded_2 = sub_class.from_pretrained(tmpdirname2)
self.parent.assertEqual(sub_config_loaded.to_dict(), sub_config_loaded_2.to_dict())
def create_and_test_config_with_num_labels(self):
config = self.config_class(**self.inputs_dict, num_labels=5)
self.parent.assertEqual(len(config.id2label), 5)
@@ -128,6 +172,9 @@ class ConfigTester:
self.parent.assertIsNotNone(config)
def check_config_arguments_init(self):
if self.config_class.sub_configs:
return # TODO: @raushan composite models are not consistent in how they set general params
kwargs = copy.deepcopy(config_common_kwargs)
config = self.config_class(**kwargs)
wrong_values = []
@@ -153,6 +200,7 @@ class ConfigTester:
self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
self.create_and_test_config_from_and_save_pretrained_subfolder()
self.create_and_test_config_from_and_save_pretrained_composite()
self.create_and_test_config_with_num_labels()
self.check_config_can_be_init_without_params()
self.check_config_arguments_init()