[tests] Parameterized test_eager_matches_sdpa_inference (#36650)
This commit is contained in:
@@ -44,7 +44,12 @@ from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
@@ -188,11 +193,15 @@ class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
self.skipTest(reason="Moshi has no strict equivalence between two modes, skipping this test.")
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions):
|
||||
self.skipTest("Test is failing, fix me :) ")
|
||||
parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
|
||||
parent_parameterized_test(self)
|
||||
|
||||
# Copied from tests.test_modeling_common.ModelTesterMixin.test_resize_tokens_embeddings
|
||||
def test_resize_tokens_embeddings(self):
|
||||
@@ -620,11 +629,11 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@unittest.skip(reason="Unimplemented. Relies on `test_eager_matches_sdpa_generate` to check correctness.")
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The Moshi model does not have support dynamic compile yet")
|
||||
|
||||
Reference in New Issue
Block a user