🔴[Attention] Attention refactor for Whisper-based models (#38235)
* start refactoring whisper
* revert for now
* first step
* carry over attn fixes
* check if this works
* whisper has an off by one somewhere - cutting mask in any interface
* make it based on interface
* remove some tests that were skipped but now work
* some fixes for whisper tests
* interface changes
* change the order of fix
* some attention adjustments for eager + TP
* fix scaling
* mask changes
* why does whisper contain those extra seq lens?
* fix from config for fa2 as input_ids is invalid
* fix another test
* another fix
* disable flex attn due to compile issues
* copies and refactor for qwen audio since it somewhat relies on whisper
* fix scaling and smaller things
* retrigger
* new new interface version + more fixups
* adjust qwen
* add comment
* forgot this one
* change copies as whisper cuts on the mask
* add guard
* add flex attention
* switch to new mask function + add skips for torchscript
* remove old api with cache position
* last changes?
* trigger ci
This commit is contained in:
@@ -4268,24 +4268,28 @@ class ModelTesterMixin:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# TODO: to change it in the future with other relevant auto classes
|
||||
fa2_model = model_class._from_config(
|
||||
config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
|
||||
config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||
dummy_input = inputs_dict[fa2_model.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
if fa2_model.config.is_encoder_decoder:
|
||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
_ = fa2_model(
|
||||
input_ids=dummy_input,
|
||||
dummy_input,
|
||||
attention_mask=dummy_attention_mask,
|
||||
decoder_input_ids=dummy_input.clone(),
|
||||
decoder_attention_mask=dummy_attention_mask.clone(),
|
||||
decoder_input_ids=dummy_decoder_input_ids,
|
||||
decoder_attention_mask=dummy_decoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||
_ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fa2_model.save_pretrained(tmpdirname)
|
||||
|
||||
Reference in New Issue
Block a user