Support for SDPA for SAM models (#34110)

* feat: add support for sdpa and gradient checkpointing

* fix: ruff format

* fix: config sdpa

* fix: sdpa layer naming convention

* fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states

* test: skip incompatible tests and fix loading issue with sdpa

- Updated tests to skip cases flash and dynamic compile.
- Minor adjustment to ensure correct loading of model with sdpa for dispatch test.

* style: apply Ruff formatting

* ruff fix again after rebase

* [run-slow] sam

* [run-slow] sam

* refactor: Address review comments and improve sub-config handling in SAM model tests

- Added attributes for sub_configs as per PR #34410.
- Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config.
- Added class attribute _is_composite=True to the tester class
- test_sdpa_can_dispatch_composite_models added

* [run-slow] sam

* style: ruff

* [run-slow] sam

* style: ruff again ...

* [run-slow] sam
This commit is contained in:
Magnus
2024-12-17 14:46:05 +01:00
committed by GitHub
parent 747f361da1
commit 6eb00dd2f0
4 changed files with 256 additions and 31 deletions

View File

@@ -4202,16 +4202,20 @@ class ModelTesterMixin:
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if hasattr(outputs_eager, "vision_hidden_states"):
logits_eager = outputs_eager.vision_hidden_states[-1]
logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
else:
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
@@ -4287,6 +4291,8 @@ class ModelTesterMixin:
)
if config.model_type in ["idefics", "idefics2", "idefics3"]:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
if config.model_type in ["sam"]:
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname: