Load sub-configs from composite configs (#34410)
* save/load sub-configs * nit forgot these * fix copies * move test to common * use dict for sub-configs * add load-save-laod test * clean up modeling check * oops this are correct keys * fix some tests, missed some composite configs * this model was missed
This commit is contained in:
committed by
GitHub
parent
5e1fd4e204
commit
893ad04fad
@@ -17,12 +17,17 @@ import copy
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
from .utils.test_configuration_utils import config_common_kwargs
|
||||
|
||||
|
||||
transformers_module = direct_transformers_import(Path(__file__).parent)
|
||||
|
||||
|
||||
class ConfigTester:
|
||||
def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
|
||||
self.parent = parent
|
||||
@@ -35,9 +40,10 @@ class ConfigTester:
|
||||
config = self.config_class(**self.inputs_dict)
|
||||
common_properties = (
|
||||
["hidden_size", "num_attention_heads", "num_hidden_layers"]
|
||||
if self.common_properties is None
|
||||
if self.common_properties is None and not self.config_class.sub_configs
|
||||
else self.common_properties
|
||||
)
|
||||
common_properties = [] if common_properties is None else common_properties
|
||||
|
||||
# Add common fields for text models
|
||||
if self.has_text_modality:
|
||||
@@ -110,6 +116,44 @@ class ConfigTester:
|
||||
|
||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def create_and_test_config_from_and_save_pretrained_composite(self):
|
||||
"""
|
||||
Tests that composite or nested cofigs can be loaded and saved correctly. In case the config
|
||||
has a sub-config, we should be able to call `sub_config.from_pretrained('general_config_file')`
|
||||
and get a result same as if we loaded the whole config and obtained `config.sub_config` from it.
|
||||
"""
|
||||
config = self.config_class(**self.inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
config.save_pretrained(tmpdirname)
|
||||
general_config_loaded = self.config_class.from_pretrained(tmpdirname)
|
||||
general_config_dict = config.to_dict()
|
||||
|
||||
# Iterate over all sub_configs if there are any and load them with their own classes
|
||||
sub_configs = self.config_class.sub_configs
|
||||
for sub_config_key, sub_class in sub_configs.items():
|
||||
if sub_class.__name__ == "AutoConfig":
|
||||
sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__
|
||||
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
|
||||
else:
|
||||
sub_config_loaded = sub_class.from_pretrained(tmpdirname)
|
||||
|
||||
# Pop `transformers_version`, it never exists when a config is part of a general composite config
|
||||
# Verify that loading with subconfig class results in same dict as if we loaded with general composite config class
|
||||
sub_config_loaded_dict = sub_config_loaded.to_dict()
|
||||
sub_config_loaded_dict.pop("transformers_version", None)
|
||||
self.parent.assertEqual(sub_config_loaded_dict, general_config_dict[sub_config_key])
|
||||
|
||||
# Verify that the loaded config type is same as in the general config
|
||||
type_from_general_config = type(getattr(general_config_loaded, sub_config_key))
|
||||
self.parent.assertTrue(isinstance(sub_config_loaded, type_from_general_config))
|
||||
|
||||
# Now save only the sub-config and load it back to make sure the whole load-save-load pipeline works
|
||||
with tempfile.TemporaryDirectory() as tmpdirname2:
|
||||
sub_config_loaded.save_pretrained(tmpdirname2)
|
||||
sub_config_loaded_2 = sub_class.from_pretrained(tmpdirname2)
|
||||
self.parent.assertEqual(sub_config_loaded.to_dict(), sub_config_loaded_2.to_dict())
|
||||
|
||||
def create_and_test_config_with_num_labels(self):
|
||||
config = self.config_class(**self.inputs_dict, num_labels=5)
|
||||
self.parent.assertEqual(len(config.id2label), 5)
|
||||
@@ -128,6 +172,9 @@ class ConfigTester:
|
||||
self.parent.assertIsNotNone(config)
|
||||
|
||||
def check_config_arguments_init(self):
|
||||
if self.config_class.sub_configs:
|
||||
return # TODO: @raushan composite models are not consistent in how they set general params
|
||||
|
||||
kwargs = copy.deepcopy(config_common_kwargs)
|
||||
config = self.config_class(**kwargs)
|
||||
wrong_values = []
|
||||
@@ -153,6 +200,7 @@ class ConfigTester:
|
||||
self.create_and_test_config_to_json_file()
|
||||
self.create_and_test_config_from_and_save_pretrained()
|
||||
self.create_and_test_config_from_and_save_pretrained_subfolder()
|
||||
self.create_and_test_config_from_and_save_pretrained_composite()
|
||||
self.create_and_test_config_with_num_labels()
|
||||
self.check_config_can_be_init_without_params()
|
||||
self.check_config_arguments_init()
|
||||
|
||||
Reference in New Issue
Block a user