🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models (#38108)

* starting attn refactor for encoder decoder models via bart (eager + sdpa)

* flash attention works, remove unnecessary code

* flex attention support for bart!, gotta check if the renaming is not too aggressive

* some comments

* skip flex grad test for standalone as done with the other test

* revert flex attn rename (for now), sdpa simplify, and todos

* more todos

* refactor mask creation for reuse

* modular attempt at biogpt

* first batch of other models

* fix attn dropout

* fix autoformer copies

* hubert

* another batch of models

* copies/style + last round of bart models --> whisper next?

* remove unnecessary _reshape function and remove copy to whisper

* add skip for decoder-only models out of enc-dec (same as in bart)

* bring back licences

* remove comment, added to pr read instead

* mostly docs

* disable sew flex attn as it's unclear attn mask for now

* oops

* test fixes for enc-dec

* torch fx fixes + try at flex attn

* skip on mbart

* some more fixes

* musicgen skip / delete old attn class logic + sdpa compose compile skip

* disable flex attn for musicgen, not worth the effort

* more fixes and style

* flex attention test for dropout and encoder decoder that don't have main input names

* informer fixes

* the weirdest thing I've encountered yet...

* style

* remove empty tensor attempt, found root cause in previous commits

* disable time series due to tests being very text centric on inputs

* add speech to text to be ignoring the other attns, also due to tests

* update docs

* remaining issues resolved ?

* update docs for current state --> nllb moe and pegasus x sdpa is questionable :D

* some models have not set the is_causal flag...

* change dtype in softmax to old behaviour + some modular fixes

* I hate it but it is what it is

* fixes from main for bart

* forgot this one

* some model fixes

* style

* current status

* marian works now

* fixing some copies

* some copy fixes + time series x informer

* last models possibly and fixes on style/copies

* some post merge fixes

* more fixes

* make attention interface callable and move warnings there

* style lol

* add comment to "unsupported"

* remove callable interface and change interface warnings + some copies

* fix

* ternary is ugly af, make it simpler

* how did that happen

* fix flex attn test

* failing the test

* no more fallback! fixing copies next

* style + attn fixed

* fixing copies and mask creation

* wrong copy

* fixup tests and disable flex attn for now

* fixup last tests?
This commit is contained in:
Anton Vlasjuk
2025-05-22 17:12:58 +02:00
committed by GitHub
parent 9895819514
commit d95c864a25
75 changed files with 8445 additions and 6342 deletions

View File

@@ -1521,3 +1521,7 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot keep gradients")
def test_flex_attention_with_grads():
return

View File

@@ -554,3 +554,7 @@ class BlenderbotStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
@unittest.skip(reason="decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot keep gradients")
def test_flex_attention_with_grads():
return

View File

@@ -563,3 +563,7 @@ class BlenderbotSmallStandaloneDecoderModelTest(ModelTesterMixin, GenerationTest
@unittest.skip(reason="decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot keep gradients")
def test_flex_attention_with_grads():
return

View File

@@ -420,6 +420,9 @@ class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)

View File

@@ -412,6 +412,10 @@ class EncoderDecoderMixin:
labels,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -445,6 +449,10 @@ class EncoderDecoderMixin:
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
# from the inner models' configurations.
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)

View File

@@ -370,6 +370,9 @@ class HubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -632,6 +635,9 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)

View File

@@ -357,7 +357,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device
[[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]],
device=torch_device,
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -374,7 +375,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device
[[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]],
device=torch_device,
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -426,7 +428,7 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
Overwriting the common test as the test is flaky on tiny models
"""
model = M2M100ForConditionalGeneration.from_pretrained(
"facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
"facebook/m2m100_418M", attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to(torch_device)
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")

View File

@@ -850,3 +850,7 @@ class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot keep gradients")
def test_flex_attention_with_grads():
return

View File

@@ -735,3 +735,7 @@ class MBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, u
@unittest.skip(reason="Decoder cannot retain gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot retain gradients")
def test_flex_attention_with_grads(self):
return

View File

@@ -728,6 +728,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# force eager attention to support output attentions
config._attn_implementation = "eager"
for model_class in self.all_model_classes:
self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict)
self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict)
@@ -805,6 +808,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
config.text_encoder.output_attentions = True
config.decoder.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -1036,30 +1042,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
).to(torch_device)
for _, module in model.named_modules():
if "FlashAttention" in module.__class__.__name__:
return
self.assertTrue(False, "FlashAttention2 modules not found in model")
self.skipTest(reason="Musicgen doesn't use the MusicgenFlashAttention2 class method.")
@require_torch_sdpa
@require_torch_gpu
@@ -1234,18 +1217,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
@@ -1276,6 +1247,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_generation_tester_mixin_inheritance(self):
pass
@unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5."))
def test_sdpa_can_compile_dynamic(self):
pass
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""

View File

@@ -731,6 +731,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# force eager attention to support output attentions
config._attn_implementation = "eager"
for model_class in self.all_model_classes:
self.check_musicgen_melody_model_output_attentions(model_class, config, **inputs_dict)
self.check_musicgen_melody_model_output_attentions_from_config(model_class, config, **inputs_dict)
@@ -807,6 +810,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
config.text_encoder.output_attentions = True
config.decoder.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -1036,30 +1042,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
).to(torch_device)
for _, module in model.named_modules():
if "FlashAttention" in module.__class__.__name__:
return
self.assertTrue(False, "FlashAttention2 modules not found in model")
self.skipTest(reason="MusicgenMelody doesn't use the MusicgenMelodyFlashAttention2 class method.")
@require_torch_sdpa
@require_torch_gpu
@@ -1234,18 +1217,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
@@ -1276,6 +1247,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_generation_tester_mixin_inheritance(self):
pass
@unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5."))
def test_sdpa_can_compile_dynamic(self):
pass
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):

View File

@@ -480,7 +480,7 @@ class PatchTSMixerModelIntegrationTests(unittest.TestCase):
)
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor([[[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],]],device=torch_device) # fmt: skip
expected_slice = torch.tensor([[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],],device=torch_device) # fmt: skip
torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
def test_forecasting_head(self):

View File

@@ -597,3 +597,7 @@ class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot keep gradients")
def test_flex_attention_with_grads():
return

View File

@@ -590,7 +590,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]], device=torch_device
[[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]]], device=torch_device
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -608,7 +608,8 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]], device=torch_device
[[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]]],
device=torch_device,
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -635,8 +636,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
batch_input,
max_length=512,
padding="max_length",
truncation_strategy="only_first",
truncation=True,
truncation="only_first",
return_tensors="pt",
)
@@ -872,3 +872,7 @@ class PegasusXStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase):
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot keep gradients")
def test_flex_attention_with_grads():
return

View File

@@ -677,3 +677,7 @@ class PLBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
@unittest.skip(reason="Decoder cannot keep gradients")
def test_flex_attention_with_grads():
return

View File

@@ -342,6 +342,9 @@ class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)

View File

@@ -300,6 +300,10 @@ class EncoderDecoderMixin:
input_features=None,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]

View File

@@ -383,6 +383,9 @@ class UniSpeechRobustModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)

View File

@@ -423,6 +423,9 @@ class UniSpeechSatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -632,6 +635,9 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)

View File

@@ -246,6 +246,10 @@ class EncoderDecoderMixin:
pixel_values=None,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -480,6 +484,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
pixel_values=None,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -670,6 +678,10 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
pixel_values=None,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -807,6 +819,10 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
labels=None,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -929,6 +945,10 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
labels=None,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -1047,6 +1067,10 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
labels=None,
**kwargs,
):
# force eager attention to support output attentions
config._attn_implementation = "eager"
decoder_config._attn_implementation = "eager"
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]

View File

@@ -570,6 +570,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -917,6 +920,9 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config.output_hidden_states = True
config.output_attentions = True
# force eager attention to support output attentions
config._attn_implementation = "eager"
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)