[configuration] allow to overwrite kwargs from subconfigs (#40241)
allow to overwrite kwargs from subconfigs
This commit is contained in:
committed by
GitHub
parent
19ffe0219d
commit
7db228a92a
@@ -825,8 +825,11 @@ class PretrainedConfig(PushToHubMixin):
|
||||
if hasattr(config, key):
|
||||
current_attr = getattr(config, key)
|
||||
# To authorize passing a custom subconfig as kwarg in models that have nested configs.
|
||||
# We need to update only custom kwarg values instead and keep other attributes in subconfig.
|
||||
if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
|
||||
value = current_attr.__class__(**value)
|
||||
current_attr_updated = current_attr.to_dict()
|
||||
current_attr_updated.update(value)
|
||||
value = current_attr.__class__(**current_attr_updated)
|
||||
setattr(config, key, value)
|
||||
if key != "dtype":
|
||||
to_remove.append(key)
|
||||
|
||||
@@ -154,6 +154,41 @@ class ConfigTester:
|
||||
sub_config_loaded_2 = sub_class.from_pretrained(tmpdirname2)
|
||||
self.parent.assertEqual(sub_config_loaded.to_dict(), sub_config_loaded_2.to_dict())
|
||||
|
||||
def create_and_test_config_from_pretrained_custom_kwargs(self):
    """
    Check that custom kwargs passed to `from_pretrained` overwrite only the requested sub-config keys.

    For composite configs, passing a partial dict for a sub-config (here `{"vocab_size": 20}` for the
    text config) must change just that key; every other value of the sub-config has to come from the
    saved checkpoint rather than being re-initialized from class defaults.

    The check is silently skipped (returns `None`) when:
        - the config is not composite, i.e. `get_text_config()` returns the config itself;
        - the parent test case has no `model_tester` exposing `get_config`;
        - the attribute holding the text config cannot be identified unambiguously.
    """
    # Check only composite configs. We can't know which attributes each type of config has, so check
    # only the text config because we are sure that all text configs have a `vocab_size`.
    config = self.config_class(**self.inputs_dict)
    # `getattr` with a default keeps this a skip (not an AttributeError) for parents without a model tester.
    model_tester = getattr(self.parent, "model_tester", None)
    if config.get_text_config() is config or not hasattr(model_tester, "get_config"):
        return

    # First create a config with non-default values and save it. Then reload it with a new
    # `vocab_size` and check that all other values are loaded from the checkpoint, not from defaults.
    non_default_inputs = model_tester.get_config().to_dict()
    config = self.config_class(**non_default_inputs)
    original_text_config = config.get_text_config()
    # Find the attribute name under which the text config is stored, by object identity.
    text_config_keys = [key for key in config if getattr(config, key) is original_text_config]

    # The heuristic is a bit brittle, so just skip the test unless exactly one attribute matches.
    if len(text_config_keys) != 1:
        return

    text_config_key = text_config_keys[0]
    with tempfile.TemporaryDirectory() as tmpdirname:
        config.save_pretrained(tmpdirname)

        # Reload from the checkpoint with `vocab_size` overridden to 20; every other key/value of the
        # text config must be identical to the one that was saved.
        config_reloaded = self.config_class.from_pretrained(tmpdirname, **{text_config_key: {"vocab_size": 20}})
        original_text_config_dict = original_text_config.to_dict()
        original_text_config_dict["vocab_size"] = 20

        text_config_reloaded_dict = config_reloaded.get_text_config().to_dict()
        self.parent.assertDictEqual(text_config_reloaded_dict, original_text_config_dict)
|
||||
|
||||
def create_and_test_config_with_num_labels(self):
|
||||
config = self.config_class(**self.inputs_dict, num_labels=5)
|
||||
self.parent.assertEqual(len(config.id2label), 5)
|
||||
@@ -204,3 +239,4 @@ class ConfigTester:
|
||||
self.create_and_test_config_with_num_labels()
|
||||
self.check_config_can_be_init_without_params()
|
||||
self.check_config_arguments_init()
|
||||
self.create_and_test_config_from_pretrained_custom_kwargs()
|
||||
|
||||
Reference in New Issue
Block a user