Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)

* Add missing cache_position argument.

* Pass cache_position to language model.

* Overwrite prepare_inputs_for_generation.

* Set model to half precision for Flash Attention test.

* Cast model to bfloat16.
This commit is contained in:
Eric Bezzam
2025-07-28 16:35:08 +02:00
committed by GitHub
parent 28f2619868
commit 7623aa3e5f
3 changed files with 19 additions and 4 deletions

View File

@@ -19,7 +19,6 @@ from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
@@ -727,6 +726,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
r"""
feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
@@ -845,6 +845,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
logits = outputs[0]
@@ -878,5 +879,19 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
attention_mask=attention_mask,
)
def prepare_inputs_for_generation(self, *args, **kwargs):
# Overwritten -- we should not pass input_features when we are in cached decoding stage
input_features = kwargs.pop("input_features", None)
cache_position = kwargs.get("cache_position")
model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
if cache_position is not None and cache_position[0] == 0:
# input_features should only be passed when we are not in cached decoding stage
model_inputs["input_features"] = input_features
return model_inputs
__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]

View File

@@ -34,6 +34,7 @@ from transformers.testing_utils import (
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -132,14 +133,12 @@ class Qwen2AudioModelTester:
@require_torch
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"""
Model tester for `Qwen2AudioForConditionalGeneration`.
"""
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
# Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
all_generative_model_classes = ()
test_pruning = False
test_head_masking = False
_is_composite = True

View File

@@ -3484,6 +3484,7 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
model.to(torch.bfloat16)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)