Refactor Attention implementation for ViT-based models (#36545)

* Refactor vit attention

* Refactor ViT-based models

* 🚨🚨🚨 Fix prefix for DPT

* Update params order

* trigger tests

* Fix Dinov2 attention

* Fix DPT attention impl propagation for backbone config

* Common test fix: config is modified in place - avoid it

* view->reshape

* Fixup

* Fixup

* Enable IJepa FA2

* Add FA2 in corresponding model docs
This commit is contained in:
Pavel Iakubovskii
2025-03-20 15:15:01 +00:00
committed by GitHub
parent 730d2a52e7
commit 66291778dd
35 changed files with 932 additions and 975 deletions

View File

@@ -255,6 +255,10 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("Inductor error for dynamic shape")
def test_sdpa_can_compile_dynamic(self):
    """Disabled for this test class; see the skip reason on the decorator."""
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -15,14 +15,24 @@
"""Testing suite for the PyTorch VideoMAE model."""
import copy
import tempfile
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from pytest import mark
from transformers import VideoMAEConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -338,6 +348,59 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
check_hidden_states_output(inputs_dict, config, model_class)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
    """Compare bfloat16 outputs with and without Flash Attention 2.

    For each model class, a model built from the test config is saved and
    reloaded twice in bfloat16: once with
    `attn_implementation="flash_attention_2"` and once without an explicit
    attention implementation. The last hidden states of both must agree
    within a loose tolerance. A final forward pass runs the Flash Attention
    model in training mode.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    for model_class in self.all_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        inputs_dict = self._prepare_for_class(inputs_dict, model_class)
        # Inputs must match the bfloat16 weights the models are loaded with below.
        inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Round-trip through disk so both variants start from identical weights.
            model.save_pretrained(tmpdirname)
            model_fa = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model_fa.to(torch_device)

            model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
            model.to(torch_device)

            outputs = model(**inputs_dict, output_hidden_states=True)
            outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)

            logits = (
                outputs.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs.decoder_hidden_states[-1]
            )
            logits_fa = (
                outputs_fa.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs_fa.decoder_hidden_states[-1]
            )

            self.assertTrue(
                torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2),
                f"{model_class.__name__}: Flash Attention 2 hidden states differ from the reference",
            )

            # check with inference + dropout
            # The forward pass below runs `model_fa`, so that is the model
            # that has to be switched to training mode.
            model_fa.train()
            _ = model_fa(**inputs_dict)
@unittest.skip(reason="Not applicable for VideoMAE")
def test_flash_attn_2_inference_equivalence_right_padding(self):
    """Disabled for this test class; see the skip reason on the decorator."""
# We will verify our results on a video of eating spaghetti
# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]

View File

@@ -19,9 +19,18 @@ import tempfile
import unittest
import numpy as np
from pytest import mark
from transformers import ViTMAEConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -269,6 +278,63 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = ViTMAEModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
    """Compare bfloat16 outputs with and without Flash Attention 2.

    For each model class, a model built from the test config is saved and
    reloaded twice in bfloat16: once with
    `attn_implementation="flash_attention_2"` and once without an explicit
    attention implementation. Both forward passes are run under the same
    random seed and their last hidden states must agree within a loose
    tolerance. A final forward pass runs the Flash Attention model in
    training mode.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    for model_class in self.all_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        inputs_dict = self._prepare_for_class(inputs_dict, model_class)
        # Inputs must match the bfloat16 weights the models are loaded with below.
        inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Round-trip through disk so both variants start from identical weights.
            model.save_pretrained(tmpdirname)
            model_fa = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model_fa.to(torch_device)

            model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
            model.to(torch_device)

            # ForPretraining model has random `noise` -> need to set seed
            # to make the test deterministic
            torch.manual_seed(12345)
            outputs = model(**inputs_dict, output_hidden_states=True)
            torch.manual_seed(12345)
            outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)

            logits = (
                outputs.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs.decoder_hidden_states[-1]
            )
            logits_fa = (
                outputs_fa.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs_fa.decoder_hidden_states[-1]
            )

            self.assertTrue(
                torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2),
                f"{model_class.__name__}: Flash Attention 2 hidden states differ from the reference",
            )

            # check with inference + dropout
            # The forward pass below runs `model_fa`, so that is the model
            # that has to be switched to training mode.
            model_fa.train()
            _ = model_fa(**inputs_dict)
# The skip reason names the model this test class covers (ViTMAE).
@unittest.skip("Not applicable for ViTMAE")
def test_flash_attn_2_inference_equivalence_right_padding(self):
    """Disabled for this test class; see the skip reason on the decorator."""
# We will verify our results on an image of cute cats
def prepare_img():

View File

@@ -130,7 +130,7 @@ class ConfigTester:
general_config_dict = config.to_dict()
# Iterate over all sub_configs if there are any and load them with their own classes
sub_configs = self.config_class.sub_configs
sub_configs = general_config_loaded.sub_configs
for sub_config_key, sub_class in sub_configs.items():
if sub_class.__name__ == "AutoConfig":
sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__

View File

@@ -315,8 +315,6 @@ class ModelTesterMixin:
return inputs_dict
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_save_load(out1, out2):
# make sure we don't have nans
out_2 = out2.cpu().numpy()
@@ -330,6 +328,7 @@ class ModelTesterMixin:
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
model.to(torch_device)
model.eval()
@@ -508,16 +507,16 @@ class ModelTesterMixin:
@is_flaky(description="low likelihood of failure, reason not yet discovered")
def test_save_load_fast_init_from_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple):
base_class = base_class[0]
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple):
base_class = base_class[0]
if model_class == base_class:
continue
@@ -2228,9 +2227,9 @@ class ModelTesterMixin:
def test_correct_missing_keys(self):
if not self.test_missing_keys:
self.skipTest(reason="test_missing_keys is set to `False`")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
base_model_prefix = model.base_model_prefix
@@ -2287,8 +2286,8 @@ class ModelTesterMixin:
@require_safetensors
def test_can_use_safetensors(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
try:
@@ -2323,9 +2322,9 @@ class ModelTesterMixin:
)
def test_load_save_without_tied_weights(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = False
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = False
model = model_class(config)
with tempfile.TemporaryDirectory() as d:
model.save_pretrained(d)
@@ -2373,8 +2372,8 @@ class ModelTesterMixin:
)
def test_model_weights_reload_no_missing_tied_weights(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)