🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)
* first try * Update modeling_utils.py * Update modeling_utils.py * big refactor * Update modeling_utils.py * style * docstrings and simplify inner workings of configs * remove all trace of _internal * Update modeling_utils.py * fix logic error * Update modeling_utils.py * recursive on config * Update configuration_utils.py * fix * Update configuration_dpt.py * Update configuration_utils.py * Update configuration_utils.py * Update modeling_idefics.py * Update modeling_utils.py * fix for old models * more old models fixup * Update modeling_utils.py * Update configuration_utils.py * Remove outdated test * remove the deepcopy!! 🥵🥵 * Update test_modeling_gpt_bigcode.py * fix qwen dispatch * restrict to only models supporting it * style * switch name * Update modeling_utils.py * Update modeling_utils.py * add tests! * fix * rypo * remove bad copies * fix * Update modeling_utils.py * additional check * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * fix * skip
This commit is contained in:
@@ -83,12 +83,13 @@ from transformers.utils.import_utils import (
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
|
||||
from test_module.custom_configuration import CustomConfig
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||
from test_module.custom_modeling import CustomModel
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
@@ -732,36 +733,21 @@ class ModelUtilsTest(TestCasePlus):
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
|
||||
# Ensure the config was set correctly
|
||||
self.assertEqual(config._attn_implementation, requested_attn_implementation)
|
||||
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL)
|
||||
# When the config is not set, the default is "eager"
|
||||
self.assertEqual(config._attn_implementation, "eager")
|
||||
self.assertEqual(config._attn_implementation_internal, None)
|
||||
self.assertEqual(config._attn_implementation, None)
|
||||
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
|
||||
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
|
||||
self.assertEqual(config._attn_implementation, "foo-bar-baz")
|
||||
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
|
||||
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
|
||||
def test_no_super_init_config_and_model(self):
|
||||
config = NoSuperInitConfig(attribute=32)
|
||||
model = NoSuperInitModel(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
new_model = NoSuperInitModel.from_pretrained(tmp_dir)
|
||||
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_checkpoint_sharding_local_bin(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user