[configuration] allow to overwrite kwargs from subconfigs (#40241)

allow to overwrite kwargs from subconfigs
This commit is contained in:
Raushan Turganbay
2025-08-22 13:31:25 +02:00
committed by GitHub
parent 19ffe0219d
commit 7db228a92a
2 changed files with 40 additions and 1 deletions

View File

@@ -825,8 +825,11 @@ class PretrainedConfig(PushToHubMixin):
if hasattr(config, key):
current_attr = getattr(config, key)
# Allow passing a custom subconfig as a kwarg in models that have nested configs.
# Only the custom kwarg values are updated; all other attributes of the subconfig are kept.
if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
value = current_attr.__class__(**value)
current_attr_updated = current_attr.to_dict()
current_attr_updated.update(value)
value = current_attr.__class__(**current_attr_updated)
setattr(config, key, value)
if key != "dtype":
to_remove.append(key)

View File

@@ -154,6 +154,41 @@ class ConfigTester:
sub_config_loaded_2 = sub_class.from_pretrained(tmpdirname2)
self.parent.assertEqual(sub_config_loaded.to_dict(), sub_config_loaded_2.to_dict())
def create_and_test_config_from_pretrained_custom_kwargs(self):
    """
    Check that a sub-config kwarg given to `from_pretrained` overrides only the requested keys.

    For composite configs, reloading a saved config with `<text config key>={"vocab_size": 20}` must
    change `vocab_size` alone: every other value of the text sub-config is expected to be identical
    to the one that was saved in the checkpoint.

    The check is silently skipped (the method returns without asserting anything) when the config is
    not composite, when the model tester has no `get_config`, or when the attribute holding the text
    config cannot be identified unambiguously.
    """
    # Only composite configs are of interest. We cannot know which attributes an arbitrary
    # sub-config has, so only the text config is checked, relying on it having a `vocab_size`.
    probe_config = self.config_class(**self.inputs_dict)
    if probe_config.get_text_config() is probe_config:
        return
    if not hasattr(self.parent.model_tester, "get_config"):
        return

    # Build a config from the model tester's values so that the saved checkpoint differs from
    # whatever the config class would produce on its own.
    tester_config = self.parent.model_tester.get_config()
    config = self.config_class(**tester_config.to_dict())
    saved_text_config = config.get_text_config()

    # Look up the attribute name under which the text config is stored, by object identity.
    matching_keys = [name for name in config if getattr(config, name) is saved_text_config]
    # The heuristic is a bit brittle, so skip the check unless exactly one attribute matches.
    if len(matching_keys) != 1:
        return
    overrides = {matching_keys[0]: {"vocab_size": 20}}

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        config.save_pretrained(checkpoint_dir)

        # Reload with `vocab_size` forced to 20; all other keys/values must match the saved ones.
        reloaded_config = self.config_class.from_pretrained(checkpoint_dir, **overrides)
        expected_text_dict = saved_text_config.to_dict()
        expected_text_dict.update(vocab_size=20)

        reloaded_text_dict = reloaded_config.get_text_config().to_dict()
        self.parent.assertDictEqual(reloaded_text_dict, expected_text_dict)
def create_and_test_config_with_num_labels(self):
config = self.config_class(**self.inputs_dict, num_labels=5)
self.parent.assertEqual(len(config.id2label), 5)
@@ -204,3 +239,4 @@ class ConfigTester:
self.create_and_test_config_with_num_labels()
self.check_config_can_be_init_without_params()
self.check_config_arguments_init()
self.create_and_test_config_from_pretrained_custom_kwargs()