Ignore non-causal mask in more cases with SDPA (#30138)
* update non-causal mask for sdpa * add test * update docstrings * add one more test * fix cross attention bug * gentler atol/rtol
This commit is contained in:
@@ -413,7 +413,7 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
|
|||||||
`(batch_size, key_value_length)`
|
`(batch_size, key_value_length)`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mask (`torch.Tensor` or `None`):
|
mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)`
|
A 2D attention mask of shape `(batch_size, key_value_length)`
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The torch dtype the created mask shall have.
|
The torch dtype the created mask shall have.
|
||||||
@@ -429,36 +429,25 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
|
|||||||
`(batch_size, key_value_length)`
|
`(batch_size, key_value_length)`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mask (`torch.Tensor` or `None`):
|
mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)`
|
A 2D attention mask of shape `(batch_size, key_value_length)`
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The torch dtype the created mask shall have.
|
The torch dtype the created mask shall have.
|
||||||
tgt_len (`int`):
|
tgt_len (`int`):
|
||||||
The target length or query length the created mask shall have.
|
The target length or query length the created mask shall have.
|
||||||
"""
|
"""
|
||||||
batch_size, key_value_length = mask.shape
|
_, key_value_length = mask.shape
|
||||||
tgt_len = tgt_len if tgt_len is not None else key_value_length
|
tgt_len = tgt_len if tgt_len is not None else key_value_length
|
||||||
|
|
||||||
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
|
||||||
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
|
||||||
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
|
||||||
is_tracing = (
|
is_tracing = (
|
||||||
torch.jit.is_tracing()
|
torch.jit.is_tracing()
|
||||||
or isinstance(mask, torch.fx.Proxy)
|
or isinstance(mask, torch.fx.Proxy)
|
||||||
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
|
||||||
if not is_tracing and torch.all(mask == 1):
|
if not is_tracing and torch.all(mask == 1):
|
||||||
if tgt_len == 1:
|
return None
|
||||||
# For query_length == 1, causal attention and bi-directional attention are the same.
|
|
||||||
return None
|
|
||||||
elif key_value_length == tgt_len:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
# Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
|
|
||||||
# may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
|
||||||
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
|
||||||
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
|
||||||
else:
|
else:
|
||||||
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
||||||
|
|
||||||
|
|||||||
@@ -432,7 +432,9 @@ class BertSdpaSelfAttention(BertSelfAttention):
|
|||||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
|
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
|
||||||
# a causal mask in case tgt_len == 1.
|
# a causal mask in case tgt_len == 1.
|
||||||
is_causal = True if self.is_decoder and attention_mask is None and tgt_len > 1 else False
|
is_causal = (
|
||||||
|
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
|
||||||
|
)
|
||||||
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
query_layer,
|
query_layer,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import BertConfig, is_torch_available
|
from transformers import AutoTokenizer, BertConfig, is_torch_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
@@ -747,3 +747,36 @@ class BertModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
|
self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
def test_sdpa_ignored_mask(self):
|
||||||
|
pkv = []
|
||||||
|
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="eager")
|
||||||
|
model_sdpa = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="sdpa")
|
||||||
|
|
||||||
|
model = model.eval()
|
||||||
|
model_sdpa = model_sdpa.eval()
|
||||||
|
|
||||||
|
for _ in range(model.config.num_hidden_layers):
|
||||||
|
num_heads = model.config.num_attention_heads
|
||||||
|
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||||
|
pkv.append([torch.rand(1, num_heads, 3, head_dim), torch.rand(1, num_heads, 3, head_dim)])
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
|
||||||
|
inp = tokenizer("I am in Paris and", return_tensors="pt")
|
||||||
|
|
||||||
|
del inp["attention_mask"]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
res_eager = model(**inp)
|
||||||
|
res_sdpa = model_sdpa(**inp)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Case where query length != kv_length.
|
||||||
|
res_eager = model(**inp, past_key_values=pkv)
|
||||||
|
res_sdpa = model_sdpa(**inp, past_key_values=pkv)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user