🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)

* first try

* Update modeling_utils.py

* Update modeling_utils.py

* big refactor

* Update modeling_utils.py

* style

* docstrings and simplify inner workings of configs

* remove all trace of _internal

* Update modeling_utils.py

* fix logic error

* Update modeling_utils.py

* recursive on config

* Update configuration_utils.py

* fix

* Update configuration_dpt.py

* Update configuration_utils.py

* Update configuration_utils.py

* Update modeling_idefics.py

* Update modeling_utils.py

* fix for old models

* more old models fixup

* Update modeling_utils.py

* Update configuration_utils.py

* Remove outdated test

* remove the deepcopy!! 🥵🥵

* Update test_modeling_gpt_bigcode.py

* fix qwen dispatch

* restrict to only models supporting it

* style

* switch name

* Update modeling_utils.py

* Update modeling_utils.py

* add tests!

* fix

* rypo

* remove bad copies

* fix

* Update modeling_utils.py

* additional check

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix

* skip
This commit is contained in:
Cyril Vallez
2025-07-18 13:41:54 +02:00
committed by GitHub
parent 2b819ba4e3
commit 4ded9a4113
33 changed files with 472 additions and 323 deletions

View File

@@ -4783,6 +4783,126 @@ class ModelTesterMixin:
f"All parameters should be on meta device, but found {unique_devices}.",
)
def test_internal_model_config_and_subconfig_are_same(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
subconfig_keys = list(config.sub_configs.keys())
for model_class in self.all_model_classes:
if len(config.sub_configs) == 0:
self.skipTest(reason="No subconfigs so the test does not make sense")
# Need to deepcopy here to avoid changing the _attn_implementation in-place
model = model_class(copy.deepcopy(config))
for submodule in model.modules():
# This is a submodel
if isinstance(submodule, PreTrainedModel) and submodule.config.__class__ != model.config.__class__:
subconfig_from_model_internal = submodule.config
matching_sub_configs = []
for subconfig_key in subconfig_keys:
# Get the subconfig from the model config
subconfig_from_model_config = getattr(model.config, subconfig_key)
if subconfig_from_model_config.__class__ == subconfig_from_model_internal.__class__:
# Since some composite models have different submodels parameterized by 2 of the same config
# class instances, we need to check against a list of matching classes, and check that at least
# 1 is the exact object (instead of checking immediately for similar object)
matching_sub_configs.append(subconfig_from_model_config)
# Both should be exactly the same object, that is when instantiating the submodel when should
# absolutely not copy the subconfig
if len(matching_sub_configs) > 0:
self.assertTrue(
any(
subconfig_from_model_config is subconfig_from_model_internal
for subconfig_from_model_config in matching_sub_configs
)
)
def test_can_set_attention_dynamically(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._can_set_attn_implementation():
self.skipTest(reason="This model does not support setting its attention dynamically")
# Need to deepcopy here to avoid changing the _attn_implementation in-place
model_config = copy.deepcopy(config)
# Set eager everywhere (it sets it recursively on subconfigs)
model_config._attn_implementation = "eager"
model = model_class(model_config)
# sanity check to make sure everything is correctly eager
self.assertTrue(model.config._attn_implementation == "eager")
for subconfig_key in model.config.sub_configs:
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
if not all(
submodule._can_set_attn_implementation()
for submodule in model.modules()
if isinstance(submodule, PreTrainedModel)
):
self.skipTest(reason="Parts of this model cannot set attention dynamically")
# Some old models technically should support switching, but don't have the flags active...
if not all(
submodule._supports_sdpa for submodule in model.modules() if isinstance(submodule, PreTrainedModel)
):
self.skipTest(reason="Parts of this model don't support sdpa")
# Now, set it to sdpa
model.set_attn_implementation("sdpa")
# Check everything was correctly changed
self.assertTrue(model.config._attn_implementation == "sdpa")
for subconfig_key in model.config.sub_configs:
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
# Check we cannot set it to random values, and it raises a warning (but no crash)
with self.assertLogs("transformers.modeling_utils", level="WARNING") as cm:
model.set_attn_implementation("foo")
self.assertTrue(
any(
"Impossible to set the requested `attn_implementation`. The following error was captured:"
in warning
for warning in cm.output
)
)
# Should still be sdpa everywhere
self.assertTrue(model.config._attn_implementation == "sdpa")
for subconfig_key in model.config.sub_configs:
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
def test_can_set_attention_dynamically_composite_model(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._can_set_attn_implementation():
self.skipTest(reason="This model does not support setting its attention dynamically")
if not self._is_composite:
self.skipTest(reason="This model is not composite")
# Need to deepcopy here to avoid changing the _attn_implementation in-place
model_config = copy.deepcopy(config)
# Set eager everywhere (it sets it recursively on subconfigs)
model_config._attn_implementation = "eager"
model = model_class(model_config)
# sanity check to make sure everything is correctly eager
self.assertTrue(model.config._attn_implementation == "eager")
for subconfig_key in model.config.sub_configs:
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
if not all(
submodule._can_set_attn_implementation()
for submodule in model.modules()
if isinstance(submodule, PreTrainedModel)
):
self.skipTest(reason="Parts of this model cannot set attention dynamically")
# Now, set only top-most to sdpa (should support it if it supports the dynamic switch)
model.set_attn_implementation({"": "sdpa"})
# Check only top-most was correctly changed
self.assertTrue(model.config._attn_implementation == "sdpa")
for subconfig_key in model.config.sub_configs:
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
global_rng = random.Random()