Fix Causality Handling in Flash Attention to Support Bidirectional Attention (#39707)

Fix the is_causal logic to enable bidirectional attention

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
LucasChan
2025-08-13 00:16:28 +08:00
committed by GitHub
parent 83dbebc429
commit 0ce24f5a88

View File

@@ -58,8 +58,10 @@ def flash_attention_forward(
else:
target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
kwargs.pop("is_causal", None)
# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
is_causal = kwargs.pop("is_causal", None)
if is_causal is None:
is_causal = module.is_causal
attn_output = _flash_attention_forward(
query,
@@ -67,7 +69,7 @@ def flash_attention_forward(
value,
attention_mask,
query_length=seq_len,
is_causal=module.is_causal,
is_causal=is_causal,
dropout=dropout,
softmax_scale=scaling,
sliding_window=sliding_window,