🚨🚨[core] Completely rewrite the masking logic for all attentions (#37866)

* start

* start having a clean 4d mask primitive

* Update mask_utils.py

* Update mask_utils.py

* switch name

* Update masking_utils.py

* add a new AttentionMask tensor class

* fix import

* nits

* fixes

* use full and quandrants

* general sdpa mask for all caches

* style

* start some tests

* tests with sliding, chunked

* add styling

* test hybrid

* Update masking_utils.py

* small temp fixes

* Update modeling_gemma2.py

* compile compatible

* Update masking_utils.py

* improve

* start making it more general

* Update masking_utils.py

* generate

* make it work with flex style primitives!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* improve

* Update cache_utils.py

* Update masking_utils.py

* simplify - starting to look good!

* Update masking_utils.py

* name

* Update masking_utils.py

* style

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* small fix for flex

* flex compile

* FA2

* Update masking_utils.py

* Escape for TGI/vLLM!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* General case without cache

* rename

* full test on llama4

* small fix for FA2 guard with chunk

* Update modeling_gemma2.py

* post rebase cleanup

* FA2 supports static cache!

* Update modeling_flash_attention_utils.py

* Update flex_attention.py

* Update masking_utils.py

* Update masking_utils.py

* Update utils.py

* override for export

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update masking_utils.py

* Update masking_utils.py

* output attentions

* style

* Update masking_utils.py

* Update executorch.py

* Add doicstring

* Add license and put mask visualizer at the end

* Update test_modeling_common.py

* fix broken test

* Update test_modeling_gemma.py

* Update test_modeling_gemma2.py

* Use fullgraph=False with FA2

* Update utils.py

* change name

* Update masking_utils.py

* improve doc

* change name

* Update modeling_attn_mask_utils.py

* more explicit logic based on model's property

* pattern in config

* extend

* fixes

* make it better

* generalize to other test models

* fix

* Update masking_utils.py

* fix

* do not check mask equivalence if layer types are different

* executorch

* Update modeling_gemma2.py

* Update masking_utils.py

* use layer_idx instead

* adjust

* Update masking_utils.py

* test

* fix imports

* Update modeling_gemma2.py

* other test models

* Update modeling_llama4.py

* Update masking_utils.py

* improve

* simplify

* Update masking_utils.py

* typos

* typo

* fix

* Update masking_utils.py

* default DynamicCache

* remove default cache

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* export

* Update executorch.py

* Update executorch.py

* Update flex_attention.py

* Update executorch.py

* upstream to modular gemma 1 & 2

* Update modular_mistral.py

* switch names

* use dict

* put it in the Layer directly

* update copy model source for mask functions

* apply so many modular (hopefully 1 shot)

* use explicite dicts for make style happy

* protect import

* check docstring

* better default in hybrid caches

* qwens

* Update modular_qwen2.py

* simplify core logic!

* Update executorch.py

* qwen3 moe

* Update masking_utils.py

* Update masking_utils.py

* simplify a lot sdpa causal skip

* Update masking_utils.py

* post-rebase

* gemma3 finally

* style

* check it before

* gemma3

* More general with newer torch

* align gemma3

* Update utils.py

* Update utils.py

* Update masking_utils.py

* Update test_modeling_common.py

* Update flex_attention.py

* Update flex_attention.py

* Update flex_attention.py

* test

* executorch

* Update test_modeling_common.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update executorch.py

* Update test_modeling_common.py

* fix copies

* device

* sdpa can be used without mask -> pass the torchscript tests in this case

* Use enum for check

* revert enum and add check instead

* remove broken test

* cohere2

* some doc & reorganize the Interface

* Update tensor_parallel.py

* Update tensor_parallel.py

* doc and dummy

* Update test_modeling_paligemma2.py

* Update modeling_falcon_h1.py

* Update masking_utils.py

* executorch patch

* style

* CIs

* use register in executorch

* final comments!

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Cyril Vallez
2025-05-22 11:38:26 +02:00
committed by GitHub
parent f8630c778c
commit 163138a911
129 changed files with 2976 additions and 6800 deletions

View File

@@ -1172,25 +1172,10 @@ class ModelTesterMixin:
traced_model = torch.jit.trace(model, example_inputs, check_trace=False)
else:
main_input = inputs[main_input_name]
if model.config._attn_implementation == "sdpa":
trace_input = {main_input_name: main_input}
if "attention_mask" in inputs:
trace_input["attention_mask"] = inputs["attention_mask"]
else:
self.skipTest(reason="testing SDPA without attention_mask is not supported")
outputs = model(main_input, attention_mask=inputs["attention_mask"])
if any(isinstance(x, Cache) for x in outputs):
continue
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
outputs = model(main_input)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(model, (main_input,))
outputs = model(main_input)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(model, (main_input,))
except RuntimeError:
self.fail("Couldn't trace module.")
@@ -3907,6 +3892,11 @@ class ModelTesterMixin:
self.skipTest(
"DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile."
)
if getattr(config, "cache_implementation", None) == "hybrid":
self.skipTest(
"Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
"is a forbidden call."
)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -4346,18 +4336,31 @@ class ModelTesterMixin:
config.sliding_window = sliding_window
inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device)
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model.eval()
# Set sliding window to `True` and check that all tokens beyond window size are masked
model.config.use_sliding_window = True
config.use_sliding_window = True
config_dict = config.to_diff_dict()
if hasattr(config, "layer_types"):
del config_dict["layer_types"]
new_config = config.__class__(**config_dict)
model = model_class(new_config).to(torch_device)
model.eval()
layer_types = getattr(model.config, "layer_types", ["sliding_attention"] * config.num_hidden_layers)
attentions = model(**inputs, output_attentions=True).attentions
for layer_attention in attentions:
self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
for layer_attention, layer_type in zip(attentions, layer_types):
if layer_type == "sliding_attention":
self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
else:
self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())
# Set sliding window to `False` while keeping `sliding_window=3`
# Check that all tokens beyond window size are not masked
model.config.use_sliding_window = False
config.use_sliding_window = False
config_dict = config.to_diff_dict()
if hasattr(config, "layer_types"):
del config_dict["layer_types"]
new_config = config.__class__(**config_dict)
model = model_class(new_config).to(torch_device)
model.eval()
attentions_not_sliding = model(**inputs, output_attentions=True).attentions
for layer_attention in attentions_not_sliding:
self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())