Fix Qwen2AudioForConditionalGeneration.forward() and test_flash_attn_kernels_inference_equivalence (#39503)
* Add missing cache_position argument. * Pass cache_position to language model. * Overwrite prepare_inputs_for_generation. * Set model to half precision for Flash Attention test. * Cast model to bfloat16.
This commit is contained in:
@@ -3484,6 +3484,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
model.to(torch_device)
|
||||
model.to(torch.bfloat16)
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
Reference in New Issue
Block a user