Ignore non-causal mask in more cases with SDPA (#30138)
* update non-causal mask for sdpa
* add test
* update docstrings
* add one more test
* fix cross attention bug
* gentler atol/rtol
This commit is contained in:
@@ -16,7 +16,7 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import BertConfig, is_torch_available
|
||||
from transformers import AutoTokenizer, BertConfig, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
@@ -747,3 +747,36 @@ class BertModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
|
||||
|
||||
def test_sdpa_ignored_mask(self):
    """Compare eager and SDPA attention outputs when no attention mask is given.

    The same tiny BERT checkpoint is loaded once with the "eager" attention
    implementation and once with "sdpa". Both models are run on the same
    tokenized prompt with the attention mask removed, first without and then
    with a random past key/value cache, and their last hidden states must
    agree within tolerance.
    """
    checkpoint = "hf-internal-testing/tiny-random-BertModel"

    eager_model = BertModel.from_pretrained(checkpoint, attn_implementation="eager").eval()
    sdpa_model = BertModel.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

    # One random [key, value] pair per hidden layer, each holding 3 cached positions.
    config = eager_model.config
    n_heads = config.num_attention_heads
    per_head_dim = config.hidden_size // config.num_attention_heads
    kv_shape = (1, n_heads, 3, per_head_dim)
    pkv = [[torch.rand(*kv_shape), torch.rand(*kv_shape)] for _ in range(config.num_hidden_layers)]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inp = tokenizer("I am in Paris and", return_tensors="pt")

    # Drop the mask so both implementations run without an explicit attention mask.
    del inp["attention_mask"]

    # First pass: no cache. Second pass: query length != kv_length.
    forward_variants = ({}, {"past_key_values": pkv})

    with torch.no_grad():
        for extra_kwargs in forward_variants:
            eager_out = eager_model(**inp, **extra_kwargs)
            sdpa_out = sdpa_model(**inp, **extra_kwargs)
            outputs_match = torch.allclose(
                eager_out.last_hidden_state, sdpa_out.last_hidden_state, atol=1e-5, rtol=1e-4
            )
            self.assertTrue(outputs_match)
|
||||
|
||||
Reference in New Issue
Block a user