Refactor Attention implementation for ViT-based models (#36545)

* Refactor vit attention

* Refactor ViT-based models

* 🚨🚨🚨 Fix prefix for DPT

* Update params order

* trigger tests

* Fix Dinov2 attention

* Fix DPT attention impl propagation for backbone config

* Common test fix: config is modif. inplace - avoid it

* view->reshape

* Fixup

* Fixup

* Enable IJepa FA2

* Add FA2 in corresponding model docs
This commit is contained in:
Pavel Iakubovskii
2025-03-20 15:15:01 +00:00
committed by GitHub
parent 730d2a52e7
commit 66291778dd
35 changed files with 932 additions and 975 deletions

View File

@@ -255,6 +255,10 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Inductor error for dynamic shape")
def test_sdpa_can_compile_dynamic(self):
pass
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -15,14 +15,24 @@
"""Testing suite for the PyTorch VideoMAE model."""
import copy
import tempfile
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from pytest import mark
from transformers import VideoMAEConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -338,6 +348,59 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
check_hidden_states_output(inputs_dict, config, model_class)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
outputs = model(**inputs_dict, output_hidden_states=True)
outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(**inputs_dict)
@unittest.skip("Not applicable for VideoMAE")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
# We will verify our results on a video of eating spaghetti
# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]

View File

@@ -19,9 +19,18 @@ import tempfile
import unittest
import numpy as np
from pytest import mark
from transformers import ViTMAEConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -269,6 +278,63 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = ViTMAEModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
# ForPretraining model has random `noise` -> need to set seed
# to make the test deterministic
torch.manual_seed(12345)
outputs = model(**inputs_dict, output_hidden_states=True)
torch.manual_seed(12345)
outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)
logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs.decoder_hidden_states[-1]
)
logits_fa = (
outputs_fa.hidden_states[-1]
if not model.config.is_encoder_decoder
else outputs_fa.decoder_hidden_states[-1]
)
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
# check with inference + dropout
model.train()
_ = model_fa(**inputs_dict)
@unittest.skip("Not applicable for VideoMAE")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
# We will verify our results on an image of cute cats
def prepare_img():