[AttentionMaskConverter] ]Fix-mask-inf (#27114)
* fix? * actual fix * fixups * add dataclass to the attention mask converter * refine testing suite * make sure there are no overflows * update the test
This commit is contained in:
@@ -1266,6 +1266,9 @@ class AttentionMaskTester(unittest.TestCase):
|
||||
|
||||
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
|
||||
|
||||
# make sure there are no overflows
|
||||
assert mask_4d.min() != float("-inf")
|
||||
|
||||
context = mask_converter.sliding_window
|
||||
if mask_converter.is_causal and context is None:
|
||||
# k * (k+1) / 2 tokens are masked in triangualar masks
|
||||
@@ -1341,6 +1344,9 @@ class AttentionMaskTester(unittest.TestCase):
|
||||
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
|
||||
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
|
||||
|
||||
# check that the mask does not overflow on causal masked tokens
|
||||
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 0), (1, 0), (1, 1)])
|
||||
|
||||
def test_2d_to_4d(self):
|
||||
mask_converter = AttentionMaskConverter(is_causal=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user