From 608fa5496cdb3199c8c12523f01cdb73fe1765b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Fri, 8 Mar 2024 13:53:17 +0100
Subject: [PATCH] Make sliding window size inclusive in eager attention (#29519)

* Make sliding window size inclusive in eager attention

* Fix tests
---
 src/transformers/modeling_attn_mask_utils.py | 6 +++---
 tests/test_modeling_utils.py                 | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index faae0d763f..8ad68f39db 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -164,10 +164,10 @@ class AttentionMaskConverter:
 
         # add lower triangular sliding window mask if necessary
         if sliding_window is not None:
-            diagonal = past_key_values_length - sliding_window + 1
+            diagonal = past_key_values_length - sliding_window - 1
 
-            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
-            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
+            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+            mask.masked_fill_(context_mask, torch.finfo(dtype).min)
 
         return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index d0db5031e8..87b933425f 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -1673,7 +1673,7 @@ class AttentionMaskTester(unittest.TestCase):
     def compute_num_context_mask(self, kv_len, context, q_len):
         # This function computes the # of attention tokens that are added for
         # the sliding window
-        c_mask_len = kv_len - context
+        c_mask_len = kv_len - context - 1
         num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
         cut_mask_len = max(c_mask_len - q_len, 0)
         num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2