Support for SDPA for SAM models (#34110)
* feat: add support for sdpa and gradient checkpointing * fix: ruff format * fix: config sdpa * fix: sdpa layer naming convention * fix: update test_eager_matches_sdpa_inference to handle vision_hidden_states * test: skip incompatible tests and fix loading issue with sdpa - Updated tests to skip cases flash and dynamic compile. - Minor adjustment to ensure correct loading of model with sdpa for dispatch test. * style: apply Ruff formatting * ruff fix again after rebase * [run-slow] sam * [run-slow] sam * refactor: Address review comments and improve sub-config handling in SAM model tests - Added attributes for sub_configs as per PR #34410. - Enabled tests for configs, ensuring the composite model (SAM) has several sub-configs in the main config. - Added class attribute _is_composite=True to the tester class - test_sdpa_can_dispatch_composite_models added * [run-slow] sam * style: ruff * [run-slow] sam * style: ruff again ... * [run-slow] sam
This commit is contained in:
@@ -4202,16 +4202,20 @@ class ModelTesterMixin:
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
|
||||
logits_eager = (
|
||||
outputs_eager.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_eager.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_sdpa = (
|
||||
outputs_sdpa.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_sdpa.decoder_hidden_states[-1]
|
||||
)
|
||||
if hasattr(outputs_eager, "vision_hidden_states"):
|
||||
logits_eager = outputs_eager.vision_hidden_states[-1]
|
||||
logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
|
||||
else:
|
||||
logits_eager = (
|
||||
outputs_eager.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_eager.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_sdpa = (
|
||||
outputs_sdpa.hidden_states[-1]
|
||||
if not is_encoder_decoder
|
||||
else outputs_sdpa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
if torch_device in ["cpu", "cuda"]:
|
||||
atol = atols[torch_device, enable_kernels, torch_dtype]
|
||||
@@ -4287,6 +4291,8 @@ class ModelTesterMixin:
|
||||
)
|
||||
if config.model_type in ["idefics", "idefics2", "idefics3"]:
|
||||
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
|
||||
if config.model_type in ["sam"]:
|
||||
self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
Reference in New Issue
Block a user