🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)
* first try * Update modeling_utils.py * Update modeling_utils.py * big refactor * Update modeling_utils.py * style * docstrings and simplify inner workings of configs * remove all trace of _internal * Update modeling_utils.py * fix logic error * Update modeling_utils.py * recursive on config * Update configuration_utils.py * fix * Update configuration_dpt.py * Update configuration_utils.py * Update configuration_utils.py * Update modeling_idefics.py * Update modeling_utils.py * fix for old models * more old models fixup * Update modeling_utils.py * Update configuration_utils.py * Remove outdated test * remove the deepcopy!! 🥵🥵 * Update test_modeling_gpt_bigcode.py * fix qwen dispatch * restrict to only models supporting it * style * switch name * Update modeling_utils.py * Update modeling_utils.py * add tests! * fix * rypo * remove bad copies * fix * Update modeling_utils.py * additional check * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * fix * skip
This commit is contained in:
@@ -1085,6 +1085,10 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@unittest.skip("T5 backbone deepcopies the configs, and fixing it would be more involved")
|
||||
def test_internal_model_config_and_subconfig_are_same(self):
|
||||
pass
|
||||
|
||||
|
||||
class Blip2TextModelWithProjectionTester:
|
||||
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
|
||||
|
||||
@@ -542,6 +542,8 @@ class GPTBigCodeMQATest(unittest.TestCase):
|
||||
attn_pdrop=0,
|
||||
resid_pdrop=0,
|
||||
)
|
||||
# We need to set it here as it's normally set by the Model's __init__
|
||||
config._attn_implementation = "sdpa"
|
||||
return GPTBigCodeAttention(config)
|
||||
|
||||
@parameterized.expand([(seed, is_train_mode) for seed in range(5) for is_train_mode in [True, False]])
|
||||
|
||||
@@ -4783,6 +4783,126 @@ class ModelTesterMixin:
|
||||
f"All parameters should be on meta device, but found {unique_devices}.",
|
||||
)
|
||||
|
||||
def test_internal_model_config_and_subconfig_are_same(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
subconfig_keys = list(config.sub_configs.keys())
|
||||
for model_class in self.all_model_classes:
|
||||
if len(config.sub_configs) == 0:
|
||||
self.skipTest(reason="No subconfigs so the test does not make sense")
|
||||
# Need to deepcopy here to avoid changing the _attn_implementation in-place
|
||||
model = model_class(copy.deepcopy(config))
|
||||
|
||||
for submodule in model.modules():
|
||||
# This is a submodel
|
||||
if isinstance(submodule, PreTrainedModel) and submodule.config.__class__ != model.config.__class__:
|
||||
subconfig_from_model_internal = submodule.config
|
||||
matching_sub_configs = []
|
||||
for subconfig_key in subconfig_keys:
|
||||
# Get the subconfig from the model config
|
||||
subconfig_from_model_config = getattr(model.config, subconfig_key)
|
||||
if subconfig_from_model_config.__class__ == subconfig_from_model_internal.__class__:
|
||||
# Since some composite models have different submodels parameterized by 2 of the same config
|
||||
# class instances, we need to check against a list of matching classes, and check that at least
|
||||
# 1 is the exact object (instead of checking immediately for similar object)
|
||||
matching_sub_configs.append(subconfig_from_model_config)
|
||||
|
||||
# Both should be exactly the same object, that is when instantiating the submodel when should
|
||||
# absolutely not copy the subconfig
|
||||
if len(matching_sub_configs) > 0:
|
||||
self.assertTrue(
|
||||
any(
|
||||
subconfig_from_model_config is subconfig_from_model_internal
|
||||
for subconfig_from_model_config in matching_sub_configs
|
||||
)
|
||||
)
|
||||
|
||||
def test_can_set_attention_dynamically(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._can_set_attn_implementation():
|
||||
self.skipTest(reason="This model does not support setting its attention dynamically")
|
||||
|
||||
# Need to deepcopy here to avoid changing the _attn_implementation in-place
|
||||
model_config = copy.deepcopy(config)
|
||||
# Set eager everywhere (it sets it recursively on subconfigs)
|
||||
model_config._attn_implementation = "eager"
|
||||
model = model_class(model_config)
|
||||
|
||||
# sanity check to make sure everything is correctly eager
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
if not all(
|
||||
submodule._can_set_attn_implementation()
|
||||
for submodule in model.modules()
|
||||
if isinstance(submodule, PreTrainedModel)
|
||||
):
|
||||
self.skipTest(reason="Parts of this model cannot set attention dynamically")
|
||||
# Some old models technically should support switching, but don't have the flags active...
|
||||
if not all(
|
||||
submodule._supports_sdpa for submodule in model.modules() if isinstance(submodule, PreTrainedModel)
|
||||
):
|
||||
self.skipTest(reason="Parts of this model don't support sdpa")
|
||||
|
||||
# Now, set it to sdpa
|
||||
model.set_attn_implementation("sdpa")
|
||||
|
||||
# Check everything was correctly changed
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
|
||||
# Check we cannot set it to random values, and it raises a warning (but no crash)
|
||||
with self.assertLogs("transformers.modeling_utils", level="WARNING") as cm:
|
||||
model.set_attn_implementation("foo")
|
||||
self.assertTrue(
|
||||
any(
|
||||
"Impossible to set the requested `attn_implementation`. The following error was captured:"
|
||||
in warning
|
||||
for warning in cm.output
|
||||
)
|
||||
)
|
||||
|
||||
# Should still be sdpa everywhere
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "sdpa")
|
||||
|
||||
def test_can_set_attention_dynamically_composite_model(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._can_set_attn_implementation():
|
||||
self.skipTest(reason="This model does not support setting its attention dynamically")
|
||||
if not self._is_composite:
|
||||
self.skipTest(reason="This model is not composite")
|
||||
|
||||
# Need to deepcopy here to avoid changing the _attn_implementation in-place
|
||||
model_config = copy.deepcopy(config)
|
||||
# Set eager everywhere (it sets it recursively on subconfigs)
|
||||
model_config._attn_implementation = "eager"
|
||||
model = model_class(model_config)
|
||||
|
||||
# sanity check to make sure everything is correctly eager
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
if not all(
|
||||
submodule._can_set_attn_implementation()
|
||||
for submodule in model.modules()
|
||||
if isinstance(submodule, PreTrainedModel)
|
||||
):
|
||||
self.skipTest(reason="Parts of this model cannot set attention dynamically")
|
||||
|
||||
# Now, set only top-most to sdpa (should support it if it supports the dynamic switch)
|
||||
model.set_attn_implementation({"": "sdpa"})
|
||||
|
||||
# Check only top-most was correctly changed
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
for subconfig_key in model.config.sub_configs:
|
||||
self.assertTrue(getattr(model.config, subconfig_key)._attn_implementation == "eager")
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@@ -194,7 +194,6 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
"_name_or_path",
|
||||
"_commit_hash",
|
||||
"_attn_implementation_internal",
|
||||
"_attn_implementation_autoset",
|
||||
"transformers_version",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -83,12 +83,13 @@ from transformers.utils.import_utils import (
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
|
||||
from test_module.custom_configuration import CustomConfig
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||
from test_module.custom_modeling import CustomModel
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
@@ -732,36 +733,21 @@ class ModelUtilsTest(TestCasePlus):
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
|
||||
# Ensure the config was set correctly
|
||||
self.assertEqual(config._attn_implementation, requested_attn_implementation)
|
||||
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL)
|
||||
# When the config is not set, the default is "eager"
|
||||
self.assertEqual(config._attn_implementation, "eager")
|
||||
self.assertEqual(config._attn_implementation_internal, None)
|
||||
self.assertEqual(config._attn_implementation, None)
|
||||
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
|
||||
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
|
||||
self.assertEqual(config._attn_implementation, "foo-bar-baz")
|
||||
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
|
||||
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
|
||||
def test_no_super_init_config_and_model(self):
|
||||
config = NoSuperInitConfig(attribute=32)
|
||||
model = NoSuperInitModel(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
new_model = NoSuperInitModel.from_pretrained(tmp_dir)
|
||||
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_checkpoint_sharding_local_bin(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user