Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)

* Add missing cache_position argument.

* Pass cache_position to language model.

* Overwrite prepare_inputs_for_generation.

* Set model to half precision for Flash Attention test.

* Cast model to bfloat16.
This commit is contained in:
Eric Bezzam
2025-07-28 16:35:08 +02:00
committed by GitHub
parent 28f2619868
commit 7623aa3e5f
3 changed files with 19 additions and 4 deletions

View File

@@ -34,6 +34,7 @@ from transformers.testing_utils import (
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -132,14 +133,12 @@ class Qwen2AudioModelTester:
@require_torch
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"""
Model tester for `Qwen2AudioForConditionalGeneration`.
"""
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
# Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
all_generative_model_classes = ()
test_pruning = False
test_head_masking = False
_is_composite = True

View File

@@ -3484,6 +3484,7 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
model.to(torch.bfloat16)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)