Fix Causality Handling in Flash Attention to Support Bidirectional Attention (#39707)
Fix the is_causal logic to enable bidirectional attention.
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -58,8 +58,10 @@ def flash_attention_forward(
|
||||
else:
|
||||
target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
|
||||
|
||||
# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
|
||||
kwargs.pop("is_causal", None)
|
||||
# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is present
|
||||
is_causal = kwargs.pop("is_causal", None)
|
||||
if is_causal is None:
|
||||
is_causal = module.is_causal
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query,
|
||||
@@ -67,7 +69,7 @@ def flash_attention_forward(
|
||||
value,
|
||||
attention_mask,
|
||||
query_length=seq_len,
|
||||
is_causal=module.is_causal,
|
||||
is_causal=is_causal,
|
||||
dropout=dropout,
|
||||
softmax_scale=scaling,
|
||||
sliding_window=sliding_window,
|
||||
|
||||
Reference in New Issue
Block a user