Load sub-configs from composite configs (#34410)
* save/load sub-configs * nit forgot these * fix copies * move test to common * use dict for sub-configs * add load-save-laod test * clean up modeling check * oops this are correct keys * fix some tests, missed some composite configs * this model was missed
This commit is contained in:
committed by
GitHub
parent
5e1fd4e204
commit
893ad04fad
@@ -3802,22 +3802,18 @@ class ModelTesterMixin:
|
||||
self.skipTest("Model is not a composite model.")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
sub_configs = {
|
||||
key: getattr(config, key) for key in config if isinstance(getattr(config, key), PretrainedConfig)
|
||||
}
|
||||
|
||||
# set eager as it will be the one supported in all models
|
||||
# we just need to test if passing 'attn_implementation' as a dict fails or not
|
||||
attn_implementation_per_subconfig = {}
|
||||
for key, sub_config in sub_configs.items():
|
||||
for key in config.sub_configs.keys():
|
||||
attn_implementation_per_subconfig[key] = "eager"
|
||||
|
||||
config._attn_implementation = attn_implementation_per_subconfig
|
||||
model = model_class(config)
|
||||
for key in model.config:
|
||||
if isinstance(getattr(model.config, key), PretrainedConfig):
|
||||
sub_config = getattr(model.config, key)
|
||||
self.assertTrue(sub_config._attn_implementation == "eager")
|
||||
for key in config.sub_configs.keys():
|
||||
sub_config = getattr(model.config, key)
|
||||
self.assertTrue(sub_config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
@@ -3826,7 +3822,7 @@ class ModelTesterMixin:
|
||||
or "SdpaSelfAttention" in class_name
|
||||
or "FlashAttention" in class_name
|
||||
):
|
||||
raise ValueError("The eager model should not have SDPA/FA2 attention layers")
|
||||
raise ValueError(f"The eager model should not have SDPA/FA2 attention layers but got {class_name}")
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_non_composite_models(self):
|
||||
|
||||
Reference in New Issue
Block a user