Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)
* Add missing cache_position argument.
* Pass cache_position to language model.
* Overwrite prepare_inputs_for_generation.
* Set model to half precision for Flash Attention test.
* Cast model to bfloat16.
This commit is contained in:
@@ -19,7 +19,6 @@ from dataclasses import dataclass
|
|||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
@@ -727,6 +726,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
|
) -> Union[tuple, Qwen2AudioCausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
|
feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
|
||||||
@@ -845,6 +845,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
@@ -878,5 +879,19 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||||
|
# Overwritten -- we should not pass input_features when we are in cached decoding stage
|
||||||
|
|
||||||
|
input_features = kwargs.pop("input_features", None)
|
||||||
|
cache_position = kwargs.get("cache_position")
|
||||||
|
|
||||||
|
model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||||
|
|
||||||
|
if cache_position is not None and cache_position[0] == 0:
|
||||||
|
# input_features should only be passed when we are not in cached decoding stage
|
||||||
|
model_inputs["input_features"] = input_features
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
|
__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"]
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
@@ -132,14 +133,12 @@ class Qwen2AudioModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
|
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Model tester for `Qwen2AudioForConditionalGeneration`.
|
Model tester for `Qwen2AudioForConditionalGeneration`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
|
||||||
# Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
|
|
||||||
all_generative_model_classes = ()
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
_is_composite = True
|
_is_composite = True
|
||||||
|
|||||||
@@ -3484,6 +3484,7 @@ class ModelTesterMixin:
|
|||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
model.to(torch.bfloat16)
|
||||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||||
dummy_input = dummy_input.to(torch.bfloat16)
|
dummy_input = dummy_input.to(torch.bfloat16)
|
||||||
|
|||||||
Reference in New Issue
Block a user