Fix SDPA implementation in Qwen2-VL (issues with torch==2.6.0) (#36891)
* fix sdpa implementation * ruff * also modify 2_5 for consistency
This commit is contained in:
@@ -316,8 +316,10 @@ class Qwen2_5_VLVisionSdpaAttention(nn.Module):
|
|||||||
q = q.transpose(0, 1)
|
q = q.transpose(0, 1)
|
||||||
k = k.transpose(0, 1)
|
k = k.transpose(0, 1)
|
||||||
v = v.transpose(0, 1)
|
v = v.transpose(0, 1)
|
||||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
attn_output = F.scaled_dot_product_attention(
|
||||||
attn_output = attn_output.transpose(0, 1)
|
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
|
||||||
|
)
|
||||||
|
attn_output = attn_output.squeeze(0).transpose(0, 1)
|
||||||
attn_output = attn_output.reshape(seq_length, -1)
|
attn_output = attn_output.reshape(seq_length, -1)
|
||||||
attn_output = self.proj(attn_output)
|
attn_output = self.proj(attn_output)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|||||||
@@ -416,8 +416,10 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
q = q.transpose(0, 1)
|
q = q.transpose(0, 1)
|
||||||
k = k.transpose(0, 1)
|
k = k.transpose(0, 1)
|
||||||
v = v.transpose(0, 1)
|
v = v.transpose(0, 1)
|
||||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
attn_output = F.scaled_dot_product_attention(
|
||||||
attn_output = attn_output.transpose(0, 1)
|
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
|
||||||
|
)
|
||||||
|
attn_output = attn_output.squeeze(0).transpose(0, 1)
|
||||||
attn_output = attn_output.reshape(seq_length, -1)
|
attn_output = attn_output.reshape(seq_length, -1)
|
||||||
attn_output = self.proj(attn_output)
|
attn_output = self.proj(attn_output)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|||||||
Reference in New Issue
Block a user