[refactor] set attention implementation (#38974)

* update

* fix some tests

* init from config, changes it in-place, add deepcopy in tests

* fix modernbert

* don't delete thsi config attr

* update

* style and copies

* skip tests in generation

* fix style

* accidentally removed flash-attn-3, revert

* docs

* forgot about flags set to False

* fix copies

* address a few comments

* fix copies

* custom code BC
This commit is contained in:
Raushan Turganbay
2025-07-15 12:34:06 +05:00
committed by GitHub
parent 6017f5e8ed
commit 8d6259b0b8
185 changed files with 451 additions and 776 deletions

View File

@@ -699,7 +699,7 @@ class ModelTesterMixin:
def test_from_pretrained_no_checkpoint(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model = model_class(copy.deepcopy(config))
state_dict = model.state_dict()
new_model = model_class.from_pretrained(
@@ -714,7 +714,7 @@ class ModelTesterMixin:
if model_class._keep_in_fp32_modules is None:
self.skipTest(reason="Model class has no _keep_in_fp32_modules attribute defined")
model = model_class(config)
model = model_class(copy.deepcopy(config))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@@ -730,7 +730,7 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model = model_class(copy.deepcopy(config))
_keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
if _keys_to_ignore_on_save is None:
continue
@@ -766,7 +766,7 @@ class ModelTesterMixin:
continue
config.gradient_checkpointing = True
model = model_class(config)
model = model_class(copy.deepcopy(config))
self.assertTrue(model.is_gradient_checkpointing)
def test_gradient_checkpointing_enable_disable(self):
@@ -777,7 +777,7 @@ class ModelTesterMixin:
continue
# at init model should have gradient checkpointing disabled
model = model_class(config)
model = model_class(copy.deepcopy(config))
self.assertFalse(model.is_gradient_checkpointing)
# check enable works
@@ -810,7 +810,7 @@ class ModelTesterMixin:
continue
# at init model should have gradient checkpointing disabled
model = model_class(config)
model = model_class(copy.deepcopy(config))
self.assertFalse(model.is_gradient_checkpointing)
# check enable works
@@ -871,7 +871,7 @@ class ModelTesterMixin:
# First, initialize the model from config -> this ensure everything is correctly initialized, even if
# _init_weights() does not take all weights into account correctly
model_from_config = model_class(config)
model_from_config = model_class(copy.deepcopy(config))
# Here, passing an empty state dict will force all weights to be moved from meta to cpu, then be initialized
# by _init_weights()
model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={})
@@ -944,7 +944,7 @@ class ModelTesterMixin:
base_class_copy._init_weights = _mock_init_weights
base_class_copy.init_weights = _mock_all_init_weights
model = model_class(config)
model = model_class(copy.deepcopy(config))
state_dict = model.state_dict()
def check_equal(loaded):
@@ -969,7 +969,7 @@ class ModelTesterMixin:
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model = model_class(config=copy.deepcopy(configs_no_init))
for name, param in model.named_parameters():
if param.requires_grad:
data = torch.flatten(param.data)
@@ -1000,7 +1000,7 @@ class ModelTesterMixin:
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
model = model_class(config)
model = model_class(copy.deepcopy(config))
model.to(torch_device)
model.eval()
with torch.no_grad():
@@ -1075,7 +1075,7 @@ class ModelTesterMixin:
if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
batched_input_prepared = self._prepare_for_class(batched_input, model_class)
model = model_class(config).to(torch_device).eval()
model = model_class(copy.deepcopy(config)).to(torch_device).eval()
set_model_for_less_flaky_test(model)
batch_size = self.model_tester.batch_size
@@ -1932,7 +1932,7 @@ class ModelTesterMixin:
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model = model_class(copy.deepcopy(config))
model.to(torch_device)
model.eval()
@@ -2061,16 +2061,15 @@ class ModelTesterMixin:
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
model = model_class(config)
model = model_class(copy.deepcopy(original_config))
model.to(torch_device)
model.eval()
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
torch.manual_seed(0)
config.chunk_size_feed_forward = 1
model = model_class(config)
original_config.chunk_size_feed_forward = 1
model = model_class(copy.deepcopy(original_config))
model.to(torch_device)
model.eval()
@@ -2445,7 +2444,7 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model = model_class(copy.deepcopy(config))
self.assertIsInstance(model.get_input_embeddings(), nn.Embedding)
new_input_embedding_layer = nn.Embedding(10, 10)
@@ -2505,7 +2504,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config.torchscript = True
model_not_tied = model_class(config)
model_not_tied = model_class(copy.deepcopy(config))
if model_not_tied.get_output_embeddings() is None:
continue
@@ -2582,7 +2581,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.get_text_config().tie_word_embeddings = True
for model_class in self.all_model_classes:
model_tied = model_class(config)
model_tied = model_class(copy.deepcopy(config))
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
@@ -2707,7 +2706,7 @@ class ModelTesterMixin:
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model = model_class(copy.deepcopy(config))
model.to(torch_device)
model.eval()
@@ -3033,7 +3032,7 @@ class ModelTesterMixin:
continue
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model_class(copy.deepcopy(config)).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict_class)
@@ -3077,7 +3076,7 @@ class ModelTesterMixin:
continue
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model_class(copy.deepcopy(config)).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict_class)
@@ -3115,7 +3114,7 @@ class ModelTesterMixin:
continue
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model_class(copy.deepcopy(config)).eval()
model = model.to(torch_device)
torch.manual_seed(0)
@@ -3470,7 +3469,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model = model_class(copy.deepcopy(config))
num_params = model.num_parameters()
assert num_params < 1000000, (
f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
@@ -3628,7 +3627,7 @@ class ModelTesterMixin:
# set eager as it will be the one supported in all models
# we just need to test if passing 'attn_implementation' as a dict fails or not
attn_implementation_per_subconfig = {}
attn_implementation_per_subconfig = {"": "eager"}
for key in config.sub_configs.keys():
attn_implementation_per_subconfig[key] = "eager"
@@ -4717,7 +4716,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
# If it does not raise here, the test passes
with torch.device("meta"):
_ = model_class(config)
_ = model_class(copy.deepcopy(config))
@require_torch_accelerator
def test_can_load_with_device_context_manager(self):