[AttentionMaskConverter] ]Fix-mask-inf (#27114)

* fix?

* actual fix

* fixups

* add dataclass to the attention mask converter

* refine testing suite

* make sure there are no overflows

* update the test
This commit is contained in:
Arthur
2023-11-10 15:22:43 +01:00
committed by GitHub
parent 7e9f10ac94
commit 68afca3e69
2 changed files with 31 additions and 1 deletions

View File

@@ -1266,6 +1266,9 @@ class AttentionMaskTester(unittest.TestCase):
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
# make sure there are no overflows
assert mask_4d.min() != float("-inf")
context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
# k * (k+1) / 2 tokens are masked in triangualar masks
@@ -1341,6 +1344,9 @@ class AttentionMaskTester(unittest.TestCase):
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
# check that the mask does not overflow on causal masked tokens
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 0), (1, 0), (1, 1)])
def test_2d_to_4d(self):
mask_converter = AttentionMaskConverter(is_causal=False)