[fix] sliding window attention mask (#38045)
* fix sliding attn * make style * Update tests/test_modeling_common.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * no a second throught, should default to `True` fo BC --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
555715f418
commit
0a52bd2403
@@ -4323,6 +4323,45 @@ class ModelTesterMixin:
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_sliding_window_mask(self):
|
||||
"""Tests that we can control the sliding window attention behavior of a model."""
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model does not support output_attentions")
|
||||
|
||||
if not (hasattr(config, "sliding_window") and hasattr(config, "use_sliding_window")):
|
||||
self.skipTest(reason="Model does not support sliding window mask")
|
||||
|
||||
seq_len = self.model_tester.seq_length
|
||||
batch_size = self.model_tester.batch_size
|
||||
sliding_window = 3 # set to arbitrary small number
|
||||
|
||||
sliding_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
|
||||
for i in range(seq_len):
|
||||
start = max(0, i - sliding_window + 1)
|
||||
sliding_mask[i, start : i + 1] = True
|
||||
sliding_mask = sliding_mask.to(torch_device)
|
||||
|
||||
config.sliding_window = sliding_window
|
||||
inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# Set sliding window to `True` and check that all tokens beyond window size are masked
|
||||
model.config.use_sliding_window = True
|
||||
attentions = model(**inputs, output_attentions=True).attentions
|
||||
for layer_attention in attentions:
|
||||
self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
|
||||
|
||||
# Set sliding window to `False` while keeping `sliding_window=3`
|
||||
# Check that all tokens beyond window size are not masked
|
||||
model.config.use_sliding_window = False
|
||||
attentions_not_sliding = model(**inputs, output_attentions=True).attentions
|
||||
for layer_attention in attentions_not_sliding:
|
||||
self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())
|
||||
|
||||
def test_custom_4d_attention_mask(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
Reference in New Issue
Block a user