🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models (#38108)

* starting attn refactor for encoder decoder models via bart (eager + sdpa)

* flash attention works, remove unnecessary code

* flex attention support for bart!, gotta check if the renaming is not too aggressive

* some comments

* skip flex grad test for standalone as done with the other test

* revert flex attn rename (for now), sdpa simplify, and todos

* more todos

* refactor mask creation for reuse

* modular attempt at biogpt

* first batch of other models

* fix attn dropout

* fix autoformer copies

* hubert

* another batch of models

* copies/style + last round of bart models --> whisper next?

* remove unnecessary _reshape function and remove copy to whisper

* add skip for decoder-only models out of enc-dec (same as in bart)

* bring back licences

* remove comment, added to pr read instead

* mostly docs

* disable sew flex attn as it's unclear attn mask for now

* oops

* test fixes for enc-dec

* torch fx fixes + try at flex attn

* skip on mbart

* some more fixes

* musicgen skip / delete old attn class logic + sdpa compose compile skip

* disable flex attn for musicgen, not worth the effort

* more fixes and style

* flex attention test for dropout and encoder decoder that dont have main input names

* informer fixes

* the weirdest thing I've encountered yet...

* style

* remove empty tensor attempt, found core root in previous commits

* disable time series due to tests being very text centric on inputs

* add speech to text to be ignoring the other attns, also due to tests

* update docs

* remaining issues resolved ?

* update docs for current state --> nllb moe and pegasus x sdpa is questionable :D

* some models have not set the is_causal flag...

* change dtype in softmax tol old behaviour + some modular fixes

* I hate it but it is what it is

* fixes from main for bart

* forgot this one

* some model fixes

* style

* current status

* marian works now

* fixing some copies

* some copy fixes + time series x informer

* last models possibly and fixes on style/copies

* some post merge fixes

* more fixes

* make attention interface callable and move warnings there

* style lol

* add comment to "unsupported"

* remove callable interface and change interface warnings + some copies

* fix

* ternary is ugly af, make it simpler

* how did that happen

* fix flex attn test

* failing the test

* no more fallback! fixing copies next

* style + attn fixed

* fixing copies and mask creation

* wrong copy

* fixup tests and disable flex attn for now

* fixup last tests?
This commit is contained in:
Anton Vlasjuk
2025-05-22 17:12:58 +02:00
committed by GitHub
parent 9895819514
commit d95c864a25
75 changed files with 8445 additions and 6342 deletions

View File

@@ -958,6 +958,8 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
@@ -1106,7 +1108,11 @@ class ModelTesterMixin:
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
if (
attn_implementation == "sdpa"
and (not model_class._supports_sdpa or not is_torch_sdpa_available())
or config.output_attentions
):
continue
configs_no_init._attn_implementation = attn_implementation
@@ -1708,6 +1714,10 @@ class ModelTesterMixin:
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# force eager attention to support output attentions
if self.has_attentions:
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -4555,13 +4565,26 @@ class ModelTesterMixin:
# TODO: raushan, fix for composite models after making VLMs support new attn API
if not model_class._supports_flex_attn or self._is_composite:
self.skipTest(reason="This model does not support flex attention")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention"
model = model_class(config).to(device=torch_device, dtype=torch.float16)
# Flex Attention can not use dropout
if hasattr(config, "attention_dropout"):
config.attention_dropout = 0
if hasattr(config, "attention_probs_dropout_prob"):
config.attention_probs_dropout_prob = 0
model = model_class(config).to(device=torch_device)
self.assertTrue(model.config._attn_implementation == "flex_attention")
# Elaborate workaround for encoder-decoder models as some do not specify their main input
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
if config.is_encoder_decoder:
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"]
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"]
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
_ = model(inputs_dict["input_ids"].to(torch_device))
_ = model(**dummy_inputs)
def test_generation_tester_mixin_inheritance(self):
"""