Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)

* Add missing cache_position argument.

* Pass cache_position to language model.

* Override prepare_inputs_for_generation.

* Set model to half precision for Flash Attention test.

* Cast model to bfloat16.
This commit is contained in:
Eric Bezzam
2025-07-28 16:35:08 +02:00
committed by GitHub
parent 28f2619868
commit 7623aa3e5f
3 changed files with 19 additions and 4 deletions

View File

@@ -34,6 +34,7 @@ from transformers.testing_utils import (
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -132,14 +133,12 @@ class Qwen2AudioModelTester:
@require_torch
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"""
Model tester for `Qwen2AudioForConditionalGeneration`.
"""
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
# Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
all_generative_model_classes = ()
test_pruning = False
test_head_masking = False
_is_composite = True