Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)

* Add missing cache_position argument.

* Pass cache_position to language model.

* Overwrite prepare_inputs_for_generation.

* Set model to half precision for Flash Attention test.

* Cast model to bfloat16.
This commit is contained in:
Eric Bezzam
2025-07-28 16:35:08 +02:00
committed by GitHub
parent 28f2619868
commit 7623aa3e5f
3 changed files with 19 additions and 4 deletions

View File

@@ -19,7 +19,6 @@ from dataclasses import dataclass
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
import torch import torch
import torch.utils.checkpoint
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
@@ -727,6 +726,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]: ) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
r""" r"""
feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
@@ -845,6 +845,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
logits = outputs[0] logits = outputs[0]
@@ -878,5 +879,19 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
attention_mask=attention_mask, attention_mask=attention_mask,
) )
def prepare_inputs_for_generation(self, *args, **kwargs):
    """Prepare model inputs for a generation step, gating `input_features`.

    Overwritten -- `input_features` must not be forwarded while decoding
    from the cache. They are removed from `kwargs` before delegating to the
    parent implementation, and are put back into the result only when
    `cache_position` is provided and its first entry is 0.

    Args:
        *args: Positional arguments forwarded unchanged to the parent
            `prepare_inputs_for_generation`.
        **kwargs: Keyword arguments forwarded to the parent, minus
            `input_features`. `cache_position` is read but left in place.

    Returns:
        The model inputs built by the parent implementation, with an
        `"input_features"` entry added only on the step whose
        `cache_position` starts at 0.
    """
    # Pull the audio features out so the parent never sees them.
    audio_features = kwargs.pop("input_features", None)
    positions = kwargs.get("cache_position")

    prepared = super().prepare_inputs_for_generation(*args, **kwargs)

    # Cached decoding stage (or no position info): leave the features out.
    if positions is None or positions[0] != 0:
        return prepared

    # Not in the cached decoding stage: hand the features back to the model.
    prepared["input_features"] = audio_features
    return prepared
__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"] __all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]

View File

@@ -34,6 +34,7 @@ from transformers.testing_utils import (
torch_device, torch_device,
) )
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -132,14 +133,12 @@ class Qwen2AudioModelTester:
@require_torch @require_torch
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase): class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
""" """
Model tester for `Qwen2AudioForConditionalGeneration`. Model tester for `Qwen2AudioForConditionalGeneration`.
""" """
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else () all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
# Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
all_generative_model_classes = ()
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
_is_composite = True _is_composite = True

View File

@@ -3484,6 +3484,7 @@ class ModelTesterMixin:
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.to(torch.bfloat16)
dummy_input = inputs_dict[model.main_input_name][:1] dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]: if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16) dummy_input = dummy_input.to(torch.bfloat16)