🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)

* first try

* Update modeling_utils.py

* Update modeling_utils.py

* big refactor

* Update modeling_utils.py

* style

* docstrings and simplify inner workings of configs

* remove all trace of _internal

* Update modeling_utils.py

* fix logic error

* Update modeling_utils.py

* recursive on config

* Update configuration_utils.py

* fix

* Update configuration_dpt.py

* Update configuration_utils.py

* Update configuration_utils.py

* Update modeling_idefics.py

* Update modeling_utils.py

* fix for old models

* more old models fixup

* Update modeling_utils.py

* Update configuration_utils.py

* Remove outdated test

* remove the deepcopy!! 🥵🥵

* Update test_modeling_gpt_bigcode.py

* fix qwen dispatch

* restrict to only models supporting it

* style

* switch name

* Update modeling_utils.py

* Update modeling_utils.py

* add tests!

* fix

* rypo

* remove bad copies

* fix

* Update modeling_utils.py

* additional check

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix

* skip
This commit is contained in:
Cyril Vallez
2025-07-18 13:41:54 +02:00
committed by GitHub
parent 2b819ba4e3
commit 4ded9a4113
33 changed files with 472 additions and 323 deletions

View File

@@ -194,7 +194,6 @@ class ConfigTestUtils(unittest.TestCase):
"_name_or_path",
"_commit_hash",
"_attn_implementation_internal",
"_attn_implementation_autoset",
"transformers_version",
],
)

View File

@@ -83,12 +83,13 @@ from transformers.utils.import_utils import (
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
from test_module.custom_configuration import CustomConfig
if is_torch_available():
import torch
from safetensors.torch import save_file as safe_save_file
from test_module.custom_modeling import CustomModel, NoSuperInitModel
from test_module.custom_modeling import CustomModel
from torch import nn
from transformers import (
@@ -732,36 +733,21 @@ class ModelUtilsTest(TestCasePlus):
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
# Ensure the config was set correctly
self.assertEqual(config._attn_implementation, requested_attn_implementation)
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
model = AutoModelForCausalLM.from_config(config)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
config = AutoConfig.from_pretrained(TINY_MISTRAL)
# When the config is not set, the default is "eager"
self.assertEqual(config._attn_implementation, "eager")
self.assertEqual(config._attn_implementation_internal, None)
self.assertEqual(config._attn_implementation, None)
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
self.assertEqual(config._attn_implementation, "foo-bar-baz")
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = NoSuperInitModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_checkpoint_sharding_local_bin(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")