[fix] sliding window attention mask (#38045)

* fix sliding attn

* make style

* Update tests/test_modeling_common.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* no a second throught, should default to `True` fo BC

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2025-05-20 11:32:19 +02:00
committed by GitHub
parent 555715f418
commit 0a52bd2403
17 changed files with 93 additions and 36 deletions

View File

@@ -4323,6 +4323,45 @@ class ModelTesterMixin:
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
def test_sliding_window_mask(self):
"""Tests that we can control the sliding window attention behavior of a model."""
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if not self.has_attentions:
self.skipTest(reason="Model does not support output_attentions")
if not (hasattr(config, "sliding_window") and hasattr(config, "use_sliding_window")):
self.skipTest(reason="Model does not support sliding window mask")
seq_len = self.model_tester.seq_length
batch_size = self.model_tester.batch_size
sliding_window = 3 # set to arbitrary small number
sliding_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
for i in range(seq_len):
start = max(0, i - sliding_window + 1)
sliding_mask[i, start : i + 1] = True
sliding_mask = sliding_mask.to(torch_device)
config.sliding_window = sliding_window
inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device)
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model.eval()
# Set sliding window to `True` and check that all tokens beyond window size are masked
model.config.use_sliding_window = True
attentions = model(**inputs, output_attentions=True).attentions
for layer_attention in attentions:
self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
# Set sliding window to `False` while keeping `sliding_window=3`
# Check that all tokens beyond window size are not masked
model.config.use_sliding_window = False
attentions_not_sliding = model(**inputs, output_attentions=True).attentions
for layer_attention in attentions_not_sliding:
self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())
def test_custom_4d_attention_mask(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")