🔴[Attention] Attention refactor for Whisper-based models (#38235)

* start refactoring whisper

* revert for now

* first step

* carry over attn fixes

* check if this works

* whisper has an off by one somewhere - cutting mask in any interface

* make it based on interface

* remove some tests that were skipped but now work

* some fixes for whisper tests

* interface changes

* change the order of fix

* some attention adjustments for eager + TP

* fix scaling

* mask changes

* why does whisper contain those extra seq lens?

* fix from config for fa2 as input_ids is invalid

* fix another test

* another fix

* disable flex attn due to compile issues

* copies and refactor for qwen audio since it somewhat relies on whisper

* fix scaling and smaller things

* retrigger

* new new interface version + more fixups

* adjust qwen

* add comment

* forgot this one

* change copies as whisper cuts on the mask

* add guard

* add flex attention

* switch to new mask function + add skips for torchscript

* remove old api with cache position

* last changes?

* trigger ci

Author: Anton Vlasjuk
Date: 2025-05-28 13:32:38 +02:00
Committed by GitHub
Parent: 565a0052ed
Commit: badc71b9f6
9 changed files with 200 additions and 646 deletions

@@ -4268,24 +4268,28 @@ class ModelTesterMixin:
             if not model_class._supports_flash_attn_2:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             # TODO: to change it in the future with other relevant auto classes
             fa2_model = model_class._from_config(
-                config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
+                config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
             ).to(torch_device)
-            dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-            dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
+            dummy_input = inputs_dict[fa2_model.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+                dummy_input = dummy_input.to(torch.float16)
+            dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-            if config.is_encoder_decoder:
+            if fa2_model.config.is_encoder_decoder:
+                dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
+                dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
                 _ = fa2_model(
-                    input_ids=dummy_input,
+                    dummy_input,
                     attention_mask=dummy_attention_mask,
-                    decoder_input_ids=dummy_input.clone(),
-                    decoder_attention_mask=dummy_attention_mask.clone(),
+                    decoder_input_ids=dummy_decoder_input_ids,
+                    decoder_attention_mask=dummy_decoder_attention_mask,
                 )
             else:
-                _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
+                _ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
             with tempfile.TemporaryDirectory() as tmpdirname:
                 fa2_model.save_pretrained(tmpdirname)