🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)
* first try * Update modeling_utils.py * Update modeling_utils.py * big refactor * Update modeling_utils.py * style * docstrings and simplify inner workings of configs * remove all trace of _internal * Update modeling_utils.py * fix logic error * Update modeling_utils.py * recursive on config * Update configuration_utils.py * fix * Update configuration_dpt.py * Update configuration_utils.py * Update configuration_utils.py * Update modeling_idefics.py * Update modeling_utils.py * fix for old models * more old models fixup * Update modeling_utils.py * Update configuration_utils.py * Remove outdated test * remove the deepcopy!! 🥵🥵 * Update test_modeling_gpt_bigcode.py * fix qwen dispatch * restrict to only models supporting it * style * switch name * Update modeling_utils.py * Update modeling_utils.py * add tests! * fix * rypo * remove bad copies * fix * Update modeling_utils.py * additional check * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * fix * skip
This commit is contained in:
@@ -4783,6 +4783,126 @@ class ModelTesterMixin:
|
||||
f"All parameters should be on meta device, but found {unique_devices}.",
|
||||
)
|
||||
|
||||
def test_internal_model_config_and_subconfig_are_same(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
subconfig_keys = list(config.sub_configs.keys())
|
||||
for model_class in self.all_model_classes:
|
||||
if len(config.sub_configs) == 0:
|
||||
self.skipTest(reason="No subconfigs so the test does not make sense")
|
||||
# Need to deepcopy here to avoid changing the _attn_implementation in-place
|
||||
model = model_class(copy.deepcopy(config))
|
||||
|
||||
for submodule in model.modules():
|
||||
# This is a submodel
|
||||
if isinstance(submodule, PreTrainedModel) and submodule.config.__class__ != model.config.__class__:
|
||||
subconfig_from_model_internal = submodule.config
|
||||
matching_sub_configs = []
|
||||
for subconfig_key in subconfig_keys:
|
||||
# Get the subconfig from the model config
|
||||
subconfig_from_model_config = getattr(model.config, subconfig_key)
|
||||
if subconfig_from_model_config.__class__ == subconfig_from_model_internal.__class__:
|
||||
# Since some composite models have different submodels parameterized by 2 of the same config
|
||||
# class instances, we need to check against a list of matching classes, and check that at least
|
||||
# 1 is the exact object (instead of checking immediately for similar object)
|
||||
matching_sub_configs.append(subconfig_from_model_config)
|
||||
|
||||
# Both should be exactly the same object, that is when instantiating the submodel when should
|
||||
# absolutely not copy the subconfig
|
||||
if len(matching_sub_configs) > 0:
|
||||
self.assertTrue(
|
||||
any(
|
||||
subconfig_from_model_config is subconfig_from_model_internal
|
||||
for subconfig_from_model_config in matching_sub_configs
|
||||
)
|
||||
)
|
||||
|
||||
def test_can_set_attention_dynamically(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._can_set_attn_implementation():
|
||||
self.skipTest(reason="This model does not support setting its attention dynamically")
|
||||
|
||||
# Need to deepcopy here to avoid changing the _attn_implementation in-place
|
||||
model_config = copy.deepcopy(config)
|
||||
# Set eager everywhere (it sets it recursively on subconfigs)
|
||||
model_config._attn_implementation = "eager"
|
||||
model = model_class(model_config)
|
||||
|
||||
# sanity check to make sure everything is correctly eager
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
if not all(
|
||||
submodule._can_set_attn_implementation()
|
||||
for submodule in model.modules()
|
||||
if isinstance(submodule, PreTrainedModel)
|
||||
):
|
||||
self.skipTest(reason="Parts of this model cannot set attention dynamically")
|
||||
# Some old models technically should support switching, but don't have the flags active...
|
||||
if not all(
|
||||
submodule._supports_sdpa for submodule in model.modules() if isinstance(submodule, PreTrainedModel)
|
||||
):
|
||||
self.skipTest(reason="Parts of this model don't support sdpa")
|
||||
|
||||
# Now, set it to sdpa
|
||||
model.set_attn_implementation("sdpa")
|
||||
|
||||
# Check everything was correctly changed
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
|
||||
# Check we cannot set it to random values, and it raises a warning (but no crash)
|
||||
with self.assertLogs("transformers.modeling_utils", level="WARNING") as cm:
|
||||
model.set_attn_implementation("foo")
|
||||
self.assertTrue(
|
||||
any(
|
||||
"Impossible to set the requested `attn_implementation`. The following error was captured:"
|
||||
in warning
|
||||
for warning in cm.output
|
||||
)
|
||||
)
|
||||
|
||||
# Should still be sdpa everywhere
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
|
||||
def test_can_set_attention_dynamically_composite_model(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._can_set_attn_implementation():
|
||||
self.skipTest(reason="This model does not support setting its attention dynamically")
|
||||
if not self._is_composite:
|
||||
self.skipTest(reason="This model is not composite")
|
||||
|
||||
# Need to deepcopy here to avoid changing the _attn_implementation in-place
|
||||
model_config = copy.deepcopy(config)
|
||||
# Set eager everywhere (it sets it recursively on subconfigs)
|
||||
model_config._attn_implementation = "eager"
|
||||
model = model_class(model_config)
|
||||
|
||||
# sanity check to make sure everything is correctly eager
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
if not all(
|
||||
submodule._can_set_attn_implementation()
|
||||
for submodule in model.modules()
|
||||
if isinstance(submodule, PreTrainedModel)
|
||||
):
|
||||
self.skipTest(reason="Parts of this model cannot set attention dynamically")
|
||||
|
||||
# Now, set only top-most to sdpa (should support it if it supports the dynamic switch)
|
||||
model.set_attn_implementation({"": "sdpa"})
|
||||
|
||||
# Check only top-most was correctly changed
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user