[tests] Parameterized test_eager_matches_sdpa_inference (#36650)
This commit is contained in:
@@ -25,7 +25,6 @@ from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_torch_sdpa,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -34,7 +33,13 @@ from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
random_attention_mask,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -311,16 +316,12 @@ class IdeficsModelTester:
|
||||
def prepare_pixel_values(self):
|
||||
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
@require_torch_sdpa
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@unittest.skip(reason="Idefics has a hard requirement on SDPA, skipping this test")
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -349,10 +350,11 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@unittest.skip("Idefics requires both text and image inputs which is currently not done in this test.")
|
||||
def test_eager_matches_sdpa_inference(self):
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
@@ -597,10 +599,11 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
)
|
||||
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@unittest.skip("Idefics requires both text and image inputs which is currently not done in this test.")
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype):
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
|
||||
Reference in New Issue
Block a user