Load sub-configs from composite configs (#34410)

* save/load sub-configs

* nit forgot these

* fix copies

* move test to common

* use dict for sub-configs

* add load-save-load test

* clean up modeling check

* oops these are correct keys

* fix some tests, missed some composite configs

* this model was missed
This commit is contained in:
Raushan Turganbay
2024-11-05 11:34:01 +01:00
committed by GitHub
parent 5e1fd4e204
commit 893ad04fad
78 changed files with 464 additions and 1052 deletions

View File

@@ -3802,22 +3802,18 @@ class ModelTesterMixin:
self.skipTest("Model is not a composite model.")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
sub_configs = {
key: getattr(config, key) for key in config if isinstance(getattr(config, key), PretrainedConfig)
}
# set eager as it will be the one supported in all models
# we just need to test if passing 'attn_implementation' as a dict fails or not
attn_implementation_per_subconfig = {}
for key, sub_config in sub_configs.items():
for key in config.sub_configs.keys():
attn_implementation_per_subconfig[key] = "eager"
config._attn_implementation = attn_implementation_per_subconfig
model = model_class(config)
for key in model.config:
if isinstance(getattr(model.config, key), PretrainedConfig):
sub_config = getattr(model.config, key)
self.assertTrue(sub_config._attn_implementation == "eager")
for key in config.sub_configs.keys():
sub_config = getattr(model.config, key)
self.assertTrue(sub_config._attn_implementation == "eager")
for name, submodule in model.named_modules():
class_name = submodule.__class__.__name__
@@ -3826,7 +3822,7 @@ class ModelTesterMixin:
or "SdpaSelfAttention" in class_name
or "FlashAttention" in class_name
):
raise ValueError("The eager model should not have SDPA/FA2 attention layers")
raise ValueError(f"The eager model should not have SDPA/FA2 attention layers but got {class_name}")
@require_torch_sdpa
def test_sdpa_can_dispatch_non_composite_models(self):