Refactor attention for SigLIP based models (#36981)

* Update Siglip attention implementation

* Update tests for Siglip

* Remove one level of indentation

* Update test to be more specific

* Fixup

* Idefics2

* Idefics3

* Emu3

* SmolVLM

* Phi4 (just a small init update)

* Idefics2 (test fix)

* Update siglip2 tests

* Update eager

* trigger

* Clean up

* Transfer inputs to device in test

* Fixing test

* Fixing test

* Revert contiguous

* Remove unused is_flash_attn_2_available

* Move flaky to specific models
This commit is contained in:
Pavel Iakubovskii
2025-04-01 14:37:25 +01:00
committed by GitHub
parent 24e311f42b
commit 3249c5dc15
12 changed files with 563 additions and 1642 deletions

View File

@@ -344,17 +344,15 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
vision_attn = None if model.vision_model._supports_sdpa else "eager"
perceiver_attn = None if model.connector.perceiver_resampler._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == perceiver_attn)
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "eager")
self.assertTrue(model_eager.connector.perceiver_resampler.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__