Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)
* Add missing cache_position argument.
* Pass cache_position to language model.
* Overwrite prepare_inputs_for_generation.
* Set model to half precision for Flash Attention test.
* Cast model to bfloat16.
This commit is contained in:
@@ -34,6 +34,7 @@ from transformers.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
@@ -132,14 +133,12 @@ class Qwen2AudioModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `Qwen2AudioForConditionalGeneration`.
|
||||
"""
|
||||
|
||||
all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else ()
|
||||
# Doesn't run generation tests. TODO eustache/joao: some generation tests are broken, the errors seem cache-related
|
||||
all_generative_model_classes = ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
_is_composite = True
|
||||
|
||||
@@ -3484,6 +3484,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.to(torch.bfloat16)
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
Reference in New Issue
Block a user