From 064cd7cdaca7c9dc7d4359434164f2c2c17505d4 Mon Sep 17 00:00:00 2001 From: Manuel Faysse <43467008+ManuelFay@users.noreply.github.com> Date: Fri, 28 Mar 2025 09:54:21 +0100 Subject: [PATCH] Fix SDPA implementation in Qwen2-VL (issues with torch==2.6.0) (#36891) * fix sdpa implementation * ruff * also modify 2_5 for consistency --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 6 ++++-- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index fb0c8f5dbe..095a0769ed 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -316,8 +316,10 @@ class Qwen2_5_VLVisionSdpaAttention(nn.Module): q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) - attn_output = attn_output.transpose(0, 1) + attn_output = F.scaled_dot_product_attention( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.squeeze(0).transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 66b780f312..2bcfd3608d 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -416,8 +416,10 @@ class VisionSdpaAttention(nn.Module): q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) - attn_output = attn_output.transpose(0, 1) + attn_output = F.scaled_dot_product_attention( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.squeeze(0).transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output