Support for SDPA for SAM models (#34110)

* feat: add support for sdpa and gradient checkpointing

* fix: ruff format

* fix: config sdpa

* fix: sdpa layer naming convention

* fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states

* test: skip incompatible tests and fix loading issue with sdpa

- Updated tests to skip cases flash and dynamic compile.
- Minor adjustment to ensure correct loading of model with sdpa for dispatch test.

* style: apply Ruff formatting

* ruff fix again after rebase

* [run-slow] sam

* [run-slow] sam

* refactor: Address review comments and improve sub-config handling in SAM model tests

- Added attributes for sub_configs as per PR #34410.
- Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config.
- Added class attribute _is_composite=True to the tester class
- test_sdpa_can_dispatch_composite_models added

* [run-slow] sam

* style: ruff

* [run-slow] sam

* style: ruff again ...

* [run-slow] sam
This commit is contained in:
Magnus
2024-12-17 14:46:05 +01:00
committed by GitHub
parent 747f361da1
commit 6eb00dd2f0
4 changed files with 256 additions and 31 deletions

View File

@@ -14,12 +14,13 @@
# limitations under the License.
"""Testing suite for the PyTorch SAM model."""
import tempfile
import unittest
import requests
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -295,6 +296,7 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
_is_composite = True
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
@@ -311,22 +313,13 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = SamModelTester(self)
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
self.prompt_encoder_config_tester = ConfigTester(
self,
config_class=SamPromptEncoderConfig,
has_text_modality=False,
num_attention_heads=12,
num_hidden_layers=2,
)
self.mask_decoder_config_tester = ConfigTester(
self, config_class=SamMaskDecoderConfig, has_text_modality=False
common_properties = ["initializer_range"]
self.config_tester = ConfigTester(
self, config_class=SamConfig, has_text_modality=False, common_properties=common_properties
)
def test_config(self):
self.vision_config_tester.run_common_tests()
self.prompt_encoder_config_tester.run_common_tests()
self.mask_decoder_config_tester.run_common_tests()
self.config_tester.run_common_tests()
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
@@ -450,6 +443,68 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = SamModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM model can't be compiled dynamic yet")
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model.
This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention".
In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model
is loaded, because we manually replicate requested attn implementation on each sub-config when loading.
See https://github.com/huggingface/transformers/pull/32238 for more info
The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model
that has a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa = model_sdpa.eval().to(torch_device)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
# Root model determines SDPA support
attn_impl = "sdpa" if model._supports_sdpa else "eager"
# Check config propagation to submodels that support it
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.vision_encoder.config._attn_implementation == attn_impl)
self.assertTrue(model_sdpa.mask_decoder.config._attn_implementation == attn_impl)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager")
# Verify SDPA/eager layer presence
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and attn_impl == "sdpa":
raise ValueError("The SDPA model should have SDPA attention layers")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"

View File

@@ -4202,16 +4202,20 @@ class ModelTesterMixin:
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if hasattr(outputs_eager, "vision_hidden_states"):
logits_eager = outputs_eager.vision_hidden_states[-1]
logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
else:
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
@@ -4287,6 +4291,8 @@ class ModelTesterMixin:
)
if config.model_type in ["idefics", "idefics2", "idefics3"]:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
if config.model_type in ["sam"]:
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname: