🚨🚨[core] Completely rewrite the masking logic for all attentions (#37866)

* start

* start having a clean 4d mask primitive

* Update mask_utils.py

* Update mask_utils.py

* switch name

* Update masking_utils.py

* add a new AttentionMask tensor class

* fix import

* nits

* fixes

* use full and quadrants

* general sdpa mask for all caches

* style

* start some tests

* tests with sliding, chunked

* add styling

* test hybrid

* Update masking_utils.py

* small temp fixes

* Update modeling_gemma2.py

* compile compatible

* Update masking_utils.py

* improve

* start making it more general

* Update masking_utils.py

* generate

* make it work with flex style primitives!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* improve

* Update cache_utils.py

* Update masking_utils.py

* simplify - starting to look good!

* Update masking_utils.py

* name

* Update masking_utils.py

* style

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* small fix for flex

* flex compile

* FA2

* Update masking_utils.py

* Escape for TGI/vLLM!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* General case without cache

* rename

* full test on llama4

* small fix for FA2 guard with chunk

* Update modeling_gemma2.py

* post rebase cleanup

* FA2 supports static cache!

* Update modeling_flash_attention_utils.py

* Update flex_attention.py

* Update masking_utils.py

* Update masking_utils.py

* Update utils.py

* override for export

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update masking_utils.py

* Update masking_utils.py

* output attentions

* style

* Update masking_utils.py

* Update executorch.py

* Add docstring

* Add license and put mask visualizer at the end

* Update test_modeling_common.py

* fix broken test

* Update test_modeling_gemma.py

* Update test_modeling_gemma2.py

* Use fullgraph=False with FA2

* Update utils.py

* change name

* Update masking_utils.py

* improve doc

* change name

* Update modeling_attn_mask_utils.py

* more explicit logic based on model's property

* pattern in config

* extend

* fixes

* make it better

* generalize to other test models

* fix

* Update masking_utils.py

* fix

* do not check mask equivalence if layer types are different

* executorch

* Update modeling_gemma2.py

* Update masking_utils.py

* use layer_idx instead

* adjust

* Update masking_utils.py

* test

* fix imports

* Update modeling_gemma2.py

* other test models

* Update modeling_llama4.py

* Update masking_utils.py

* improve

* simplify

* Update masking_utils.py

* typos

* typo

* fix

* Update masking_utils.py

* default DynamicCache

* remove default cache

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* export

* Update executorch.py

* Update executorch.py

* Update flex_attention.py

* Update executorch.py

* upstream to modular gemma 1 & 2

* Update modular_mistral.py

* switch names

* use dict

* put it in the Layer directly

* update copy model source for mask functions

* apply so many modular (hopefully 1 shot)

* use explicit dicts to make style happy

* protect import

* check docstring

* better default in hybrid caches

* qwens

* Update modular_qwen2.py

* simplify core logic!

* Update executorch.py

* qwen3 moe

* Update masking_utils.py

* Update masking_utils.py

* simplify a lot sdpa causal skip

* Update masking_utils.py

* post-rebase

* gemma3 finally

* style

* check it before

* gemma3

* More general with newer torch

* align gemma3

* Update utils.py

* Update utils.py

* Update masking_utils.py

* Update test_modeling_common.py

* Update flex_attention.py

* Update flex_attention.py

* Update flex_attention.py

* test

* executorch

* Update test_modeling_common.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update executorch.py

* Update test_modeling_common.py

* fix copies

* device

* sdpa can be used without mask -> pass the torchscript tests in this case

* Use enum for check

* revert enum and add check instead

* remove broken test

* cohere2

* some doc & reorganize the Interface

* Update tensor_parallel.py

* Update tensor_parallel.py

* doc and dummy

* Update test_modeling_paligemma2.py

* Update modeling_falcon_h1.py

* Update masking_utils.py

* executorch patch

* style

* CIs

* use register in executorch

* final comments!

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Cyril Vallez
2025-05-22 11:38:26 +02:00
committed by GitHub
parent f8630c778c
commit 163138a911
129 changed files with 2976 additions and 6800 deletions

View File

@@ -13,7 +13,6 @@
# limitations under the License.
"""Testing suite for the PyTorch Gemma model."""
import tempfile
import unittest
import pytest
@@ -23,7 +22,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_to
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
cleanup,
is_flaky,
require_bitsandbytes,
require_flash_attn,
require_read_token,
@@ -303,39 +301,45 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Gemma flash attention does not support right padding")
@require_torch_sdpa
@require_torch_accelerator
def test_sdpa_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
self.skipTest(reason="Model does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name].to(torch_device)
model.config._attn_implementation = "sdpa"
states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[-1]
model.config._attn_implementation = "eager"
states_eager = model(dummy_input, output_hidden_states=True).hidden_states[-1]
torch.testing.assert_close(states_sdpa, states_eager, atol=1e-5, rtol=1e-5)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
model = model_class(config).to(device=torch_device, dtype=torch.float16)
dummy_input = inputs_dict[model_class.main_input_name].to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model.config._attn_implementation = "flash_attention_2"
states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[1]
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
model.config._attn_implementation = "eager"
states_eager = model(dummy_input, output_hidden_states=True).hidden_states[1]
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
# gemma flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=3e-3)
# Here we use higher tolerance and the output of the 2nd layer because otherwise small diffs add-up
torch.testing.assert_close(states_sdpa, states_eager, atol=1e-3, rtol=1e-3)
@slow

View File

@@ -154,6 +154,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip("Gemma2 eager/FA2 attention outputs are expected to be different")
def test_flash_attn_2_equivalence(self):
pass
@slow
@require_torch_accelerator

View File

@@ -25,7 +25,6 @@ from transformers import (
AutoTokenizer,
Gemma3Config,
Gemma3TextConfig,
GenerationConfig,
is_torch_available,
)
from transformers.testing_utils import (
@@ -635,46 +634,6 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684
-- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
"""
model_id = "google/gemma-3-1b-it"
attn_implementation = "sdpa"
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertGreater(input_size, model.config.sliding_window)
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
out = model.generate(**inputs, generation_config=generation_config)
out = model.generate(**inputs, generation_config=generation_config, do_sample=False)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
# Generation works beyond sliding window
self.assertGreater(out.shape[1], model.config.sliding_window)
self.assertEqual(out.shape[1], input_size + 5)
# Note: Auto-inheritance only works for models saved starting from 4.50.0
model.generation_config.transformers_version = "4.49.0"
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
out = model.generate(**inputs, generation_config=generation_config)
def test_export_text_only_with_hybrid_cache(self):
if not is_torch_greater_or_equal("2.6.0"):
self.skipTest(reason="This test requires torch >= 2.6 to run.")

View File

@@ -26,6 +26,7 @@ from transformers import (
is_vision_available,
)
from transformers.testing_utils import (
is_flaky,
require_torch,
torch_device,
)
@@ -381,3 +382,8 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache")
def test_generate_with_static_cache(self):
pass
@pytest.mark.generate
# NOTE: `is_flaky` is used as a decorator *factory* elsewhere in the test suite
# (`@is_flaky()`), so it must be called here. Applying it bare would hand the
# test function to the factory as its first argument instead of wrapping it,
# and the inherited test body would not actually be executed.
@is_flaky()
def test_generate_compile_model_forward(self):
    """Run the inherited compile-forward generation test, retrying on flaky failures."""
    super().test_generate_compile_model_forward()

View File

@@ -1172,25 +1172,10 @@ class ModelTesterMixin:
traced_model = torch.jit.trace(model, example_inputs, check_trace=False)
else:
main_input = inputs[main_input_name]
if model.config._attn_implementation == "sdpa":
trace_input = {main_input_name: main_input}
if "attention_mask" in inputs:
trace_input["attention_mask"] = inputs["attention_mask"]
else:
self.skipTest(reason="testing SDPA without attention_mask is not supported")
outputs = model(main_input, attention_mask=inputs["attention_mask"])
if any(isinstance(x, Cache) for x in outputs):
continue
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
outputs = model(main_input)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(model, (main_input,))
outputs = model(main_input)
if any(isinstance(x, Cache) for x in outputs):
continue
traced_model = torch.jit.trace(model, (main_input,))
except RuntimeError:
self.fail("Couldn't trace module.")
@@ -3907,6 +3892,11 @@ class ModelTesterMixin:
self.skipTest(
"DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile."
)
if getattr(config, "cache_implementation", None) == "hybrid":
self.skipTest(
"Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` "
"is a forbidden call."
)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -4346,18 +4336,31 @@ class ModelTesterMixin:
config.sliding_window = sliding_window
inputs["attention_mask"] = torch.ones(batch_size, seq_len).to(torch.int64).to(torch_device)
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model.eval()
# Set sliding window to `True` and check that all tokens beyond window size are masked
model.config.use_sliding_window = True
config.use_sliding_window = True
config_dict = config.to_diff_dict()
if hasattr(config, "layer_types"):
del config_dict["layer_types"]
new_config = config.__class__(**config_dict)
model = model_class(new_config).to(torch_device)
model.eval()
layer_types = getattr(model.config, "layer_types", ["sliding_attention"] * config.num_hidden_layers)
attentions = model(**inputs, output_attentions=True).attentions
for layer_attention in attentions:
self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
for layer_attention, layer_type in zip(attentions, layer_types):
if layer_type == "sliding_attention":
self.assertTrue((layer_attention[:, :, ~sliding_mask] == 0).all().item())
else:
self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())
# Set sliding window to `False` while keeping `sliding_window=3`
# Check that all tokens beyond window size are not masked
model.config.use_sliding_window = False
config.use_sliding_window = False
config_dict = config.to_diff_dict()
if hasattr(config, "layer_types"):
del config_dict["layer_types"]
new_config = config.__class__(**config_dict)
model = model_class(new_config).to(torch_device)
model.eval()
attentions_not_sliding = model(**inputs, output_attentions=True).attentions
for layer_attention in attentions_not_sliding:
self.assertFalse((layer_attention[:, :, ~sliding_mask] == 0).all().item())

View File

@@ -55,6 +55,7 @@ if is_torch_available():
convert_and_export_with_cache,
pipeline,
)
from transformers.integrations.executorch import export_with_dynamic_cache
TEST_CACHE_IMPLEMENTATIONS = [
@@ -593,22 +594,11 @@ class CacheExportIntegrationTest(unittest.TestCase):
attention_mask = inputs.attention_mask
input_ids = inputs.input_ids
past_key_values = DynamicCache()
ep = torch.export.export(
model,
(),
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
},
strict=False,
)
ep = export_with_dynamic_cache(model, input_ids, attention_mask)
res = ep.module()(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
past_key_values=DynamicCache(),
use_cache=True,
)
self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)