Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)

* Add missing cache_position argument.

* Pass cache_position to language model.

* Overwrite prepare_inputs_for_generation.

* Set model to half precision for Flash Attention test.

* Cast model to bfloat16.
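
The first two items (accept `cache_position` in `forward()`, then pass it on to the language model) can be sketched in plain Python. This is only an illustration under stated assumptions, not the code from this commit: `ToyLanguageModel` and `ToyConditionalGeneration` are hypothetical stand-ins, and the real `Qwen2AudioForConditionalGeneration.forward()` takes many more arguments.

```python
from typing import Optional, Sequence


class ToyLanguageModel:
    """Hypothetical stand-in for the wrapped language model."""

    def forward(
        self,
        input_ids: Sequence[int],
        cache_position: Optional[Sequence[int]] = None,
    ) -> dict:
        # Echo back what was received so the forwarding can be inspected.
        return {"input_ids": list(input_ids), "cache_position": cache_position}


class ToyConditionalGeneration:
    """Hypothetical wrapper; not the real Qwen2AudioForConditionalGeneration."""

    def __init__(self) -> None:
        self.language_model = ToyLanguageModel()

    def forward(
        self,
        input_ids: Sequence[int],
        cache_position: Optional[Sequence[int]] = None,
    ) -> dict:
        # Accept cache_position in the signature and hand it on unchanged,
        # instead of silently dropping it at the wrapper level.
        return self.language_model.forward(input_ids, cache_position=cache_position)


model = ToyConditionalGeneration()
result = model.forward([101, 102, 103], cache_position=[0, 1, 2])
print(result["cache_position"])
```

If the wrapper did not declare `cache_position`, a caller passing it by keyword would get a `TypeError`, and if the wrapper declared it but did not forward it, the inner model would always see `None`.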
Eric Bezzam
2025-07-28 16:35:08 +02:00
committed by GitHub
parent 28f2619868
commit 7623aa3e5f
3 changed files with 19 additions and 4 deletions

@@ -3484,6 +3484,7 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
model.to(torch.bfloat16)
dummy_input = inputs_dict[model.main_input_name][:1]
if dummy_input.dtype in [torch.float32, torch.float16]:
    dummy_input = dummy_input.to(torch.bfloat16)