[tests] Parameterized test_eager_matches_sdpa_inference (#36650)

This commit is contained in:
Joao Gante
2025-03-14 14:41:27 +00:00
committed by GitHub
parent 9215cc62d4
commit 42ebb6c23e
16 changed files with 285 additions and 1900 deletions

View File

@@ -25,7 +25,6 @@ from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@@ -34,7 +33,13 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
floats_tensor,
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -311,16 +316,12 @@ class IdeficsModelTester:
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@require_torch_sdpa
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_generate(self):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip(reason="Idefics has a hard requirement on SDPA, skipping this test")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@require_torch
@@ -349,10 +350,11 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
return inputs_dict
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip("Idefics requires both text and image inputs which is currently not done in this test.")
def test_eager_matches_sdpa_inference(self):
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
def test_model_outputs_equivalence(self):
@@ -597,10 +599,11 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
)
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip("Idefics requires both text and image inputs which is currently not done in this test.")
def test_eager_matches_sdpa_inference(self, torch_dtype):
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@pytest.mark.generate