Support for SDPA for SAM models (#34110)
* feat: add support for sdpa and gradient checkpointing * fix: ruff format * fix: config sdpa * fix: sdpa layer naming convention * fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states * test: skip incompatible tests and fix loading issue with sdpa - Updated tests to skip cases flash and dynamic compile. - Minor adjustment to ensure correct loading of model with sdpa for dispatch test. * style: apply Ruff formatting * ruff fix again after rebase * [run-slow] sam * [run-slow] sam * refactor: Address review comments and improve sub-config handling in SAM model tests - Added attributes for sub_configs as per PR #34410. - Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config. - Added class attribute _is_composite=True to the tester class - test_sdpa_can_dispatch_composite_models added * [run-slow] sam * style: ruff * [run-slow] sam * style: ruff again ... * [run-slow] sam
This commit is contained in:
@@ -14,12 +14,13 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch SAM model."""
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
|
||||
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -295,6 +296,7 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torchscript = False
|
||||
_is_composite = True
|
||||
|
||||
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
|
||||
def is_pipeline_test_to_skip(
|
||||
@@ -311,22 +313,13 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SamModelTester(self)
|
||||
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
|
||||
self.prompt_encoder_config_tester = ConfigTester(
|
||||
self,
|
||||
config_class=SamPromptEncoderConfig,
|
||||
has_text_modality=False,
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=2,
|
||||
)
|
||||
self.mask_decoder_config_tester = ConfigTester(
|
||||
self, config_class=SamMaskDecoderConfig, has_text_modality=False
|
||||
common_properties = ["initializer_range"]
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=SamConfig, has_text_modality=False, common_properties=common_properties
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.vision_config_tester.run_common_tests()
|
||||
self.prompt_encoder_config_tester.run_common_tests()
|
||||
self.mask_decoder_config_tester.run_common_tests()
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
@@ -450,6 +443,68 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
model = SamModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
    """Skip the dynamic-compile SDPA check for SAM; this method only calls `skipTest`."""
    skip_reason = "SAM model can't be compiled dynamic yet"
    self.skipTest(reason=skip_reason)
|
||||
|
||||
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
    """
    Tests that the attention implementation requested at load time is dispatched to SAM's sub-models.

    Every model class is saved to a temporary directory and reloaded twice, once with
    `attn_implementation="sdpa"` and once with `attn_implementation="eager"`. The test then checks
    `config._attn_implementation` on the root model and on its `vision_encoder` and `mask_decoder`
    sub-models, and looks at module class names, since SDPA layers are expected to be called
    "SdpaAttention" or "SdpaSelfAttention".
    See https://github.com/huggingface/transformers/pull/32238 for more info

    The test is skipped when the tester reports no attention support or is not marked as composite.

    Raises:
        ValueError: if the eager model contains SDPA attention layers, or if the SDPA model contains none
            although the root model declares SDPA support.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    if not self._is_composite:
        self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

    sdpa_layer_markers = ("SdpaAttention", "SdpaSelfAttention")

    def has_sdpa_layers(loaded_model):
        # True as soon as one submodule's class name carries an SDPA marker.
        return any(
            marker in submodule.__class__.__name__
            for submodule in loaded_model.modules()
            for marker in sdpa_layer_markers
        )

    for model_class in self.all_model_classes:
        # Only the config is needed here; the prepared inputs are never fed to the model.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
            model_sdpa = model_sdpa.eval().to(torch_device)

            model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
            model_eager = model_eager.eval().to(torch_device)

            # Root model determines SDPA support
            attn_impl = "sdpa" if model._supports_sdpa else "eager"

            # Check config propagation to submodels that support it
            self.assertEqual(model_sdpa.config._attn_implementation, "sdpa")
            self.assertEqual(model_sdpa.vision_encoder.config._attn_implementation, attn_impl)
            self.assertEqual(model_sdpa.mask_decoder.config._attn_implementation, attn_impl)

            self.assertEqual(model_eager.config._attn_implementation, "eager")
            self.assertEqual(model_eager.vision_encoder.config._attn_implementation, "eager")
            self.assertEqual(model_eager.mask_decoder.config._attn_implementation, "eager")

            # Verify SDPA/eager layer presence
            if attn_impl == "sdpa" and not has_sdpa_layers(model_sdpa):
                raise ValueError("The SDPA model should have SDPA attention layers")

            if has_sdpa_layers(model_eager):
                raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
|
||||
def prepare_image():
|
||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
|
||||
@@ -4202,16 +4202,20 @@ class ModelTesterMixin:
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
|
||||
logits_eager = (
|
||||
outputs_eager.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_eager.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_sdpa = (
|
||||
outputs_sdpa.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_sdpa.decoder_hidden_states[-1]
|
||||
)
|
||||
if hasattr(outputs_eager, "vision_hidden_states"):
|
||||
logits_eager = outputs_eager.vision_hidden_states[-1]
|
||||
logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
|
||||
else:
|
||||
logits_eager = (
|
||||
outputs_eager.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_eager.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_sdpa = (
|
||||
outputs_sdpa.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_sdpa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
if torch_device in ["cpu", "cuda"]:
|
||||
atol = atols[torch_device, enable_kernels, torch_dtype]
|
||||
@@ -4287,6 +4291,8 @@ class ModelTesterMixin:
|
||||
)
|
||||
if config.model_type in ["idefics", "idefics2", "idefics3"]:
|
||||
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
|
||||
if config.model_type in ["sam"]:
|
||||
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
Reference in New Issue
Block a user