[tests] Parameterized test_eager_matches_sdpa_inference (#36650)

This commit is contained in:
Joao Gante
2025-03-14 14:41:27 +00:00
committed by GitHub
parent 9215cc62d4
commit 42ebb6c23e
16 changed files with 285 additions and 1900 deletions

View File

@@ -44,7 +44,12 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
floats_tensor,
ids_tensor,
)
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -188,11 +193,15 @@ class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
logits_processor_kwargs = {}
return logits_processor_kwargs
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(reason="Moshi has no strict equivalence between two modes, skipping this test.")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions):
self.skipTest("Test is failing, fix me :) ")
parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
parent_parameterized_test(self)
# Copied from tests.test_modeling_common.ModelTesterMixin.test_resize_tokens_embeddings
def test_resize_tokens_embeddings(self):
@@ -620,11 +629,11 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip(reason="Unimplemented. Relies on `test_eager_matches_sdpa_generate` to check correctness.")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@unittest.skip(reason="The Moshi model does not have support dynamic compile yet")