[core] Fix attn_implementation setter with missing sub_configs (#39855)
* fix * add sub_configs * remove case for attention setter * fix None * Add test * Fix sub-configs * fix tests_config * fix consistency * fix fsmt * fix
This commit is contained in:
committed by
GitHub
parent
2a9febd632
commit
16d6faef9a
@@ -141,6 +141,7 @@ class ConfigTester:
|
||||
# Verify that loading with subconfig class results in same dict as if we loaded with general composite config class
|
||||
sub_config_loaded_dict = sub_config_loaded.to_dict()
|
||||
sub_config_loaded_dict.pop("transformers_version", None)
|
||||
general_config_dict[sub_config_key].pop("transformers_version", None)
|
||||
self.parent.assertEqual(sub_config_loaded_dict, general_config_dict[sub_config_key])
|
||||
|
||||
# Verify that the loaded config type is same as in the general config
|
||||
|
||||
@@ -4812,6 +4812,25 @@ class ModelTesterMixin:
|
||||
f"All parameters should be on meta device, but found {unique_devices}.",
|
||||
)
|
||||
|
||||
def test_config_attn_implementation_setter(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_attn_implementation_setter(config: PretrainedConfig, attn_implementation: str):
|
||||
if not config._attn_implementation == attn_implementation:
|
||||
raise ValueError(
|
||||
f"Unexpected attn_implementation for config {config.__class__.__name__}: "
|
||||
f"{config._attn_implementation} != {attn_implementation}"
|
||||
)
|
||||
for attribute_value in config.__dict__.values():
|
||||
if isinstance(attribute_value, PretrainedConfig):
|
||||
check_attn_implementation_setter(attribute_value, attn_implementation)
|
||||
|
||||
config._attn_implementation = "eager"
|
||||
check_attn_implementation_setter(config, "eager")
|
||||
|
||||
config._attn_implementation = "sdpa"
|
||||
check_attn_implementation_setter(config, "sdpa")
|
||||
|
||||
def test_internal_model_config_and_subconfig_are_same(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
subconfig_keys = list(config.sub_configs.keys())
|
||||
|
||||
Reference in New Issue
Block a user