Refactor attention for SigLIP based models (#36981)
* Update Siglip attention implementation * Update tests for Siglip * Remove one level of indentation * Update test to be more specific * Fixup * Idefics2 * Idefics3 * Emu3 * SmolVLM * Phi4 (just init small update) * Idefics2 (test fix) * Update siglip2 tests * Update eager * trigger * Clean up * Transfer inputs to device in test * Fixing test * Fixing test * Revert contiguous * Remove unused is_flash_attn_2_available * Move flaky to specific models
This commit is contained in:
committed by
GitHub
parent
24e311f42b
commit
3249c5dc15
@@ -344,17 +344,15 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
vision_attn = None if model.vision_model._supports_sdpa else "eager"
|
||||
perceiver_attn = None if model.connector.perceiver_resampler._supports_sdpa else "eager"
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
|
||||
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == perceiver_attn)
|
||||
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "sdpa")
|
||||
|
||||
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.connector.perceiver_resampler.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
|
||||
Reference in New Issue
Block a user