Fix SDPA attention precision issue in Qwen2.5-VL (#37363)
* solve conflicts and remove redundant attention_mask in qwenvit * update decoded text check * remove trailing whitespace
This commit is contained in:
@@ -296,7 +296,6 @@ class Glm4vVisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@@ -320,27 +319,51 @@ class Glm4vVisionAttention(nn.Module):
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
# Flash Attention 2: Use cu_seqlens for variable length attention
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Other implementations: Process each chunk separately
|
||||
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
splits = [
|
||||
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
||||
]
|
||||
|
||||
attn_outputs = [
|
||||
attention_interface(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
for q, k, v in zip(*splits)
|
||||
]
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
@@ -361,7 +384,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@@ -369,7 +391,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@@ -467,25 +488,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb, pos_ids
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention masl only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -515,14 +517,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
@@ -603,25 +603,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb, pos_ids
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention masl only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -651,14 +632,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
@@ -957,7 +957,6 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@@ -970,27 +969,51 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
# Flash Attention 2: Use cu_seqlens for variable length attention
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Other implementations: Process each chunk separately
|
||||
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
splits = [
|
||||
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
||||
]
|
||||
|
||||
attn_outputs = [
|
||||
attention_interface(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
for q, k, v in zip(*splits)
|
||||
]
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
@@ -1024,14 +1047,12 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states),
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@@ -1191,25 +1212,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention masl only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -1257,12 +1259,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
|
||||
@@ -1935,7 +1935,6 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@@ -1948,27 +1947,51 @@ class Qwen2_5OmniVisionAttention(nn.Module):
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
# Flash Attention 2: Use cu_seqlens for variable length attention
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Other implementations: Process each chunk separately
|
||||
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
splits = [
|
||||
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
||||
]
|
||||
|
||||
attn_outputs = [
|
||||
attention_interface(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
for q, k, v in zip(*splits)
|
||||
]
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
@@ -1985,14 +2008,12 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states),
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@@ -2007,25 +2028,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention masl only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -2073,12 +2075,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.merger(hidden_states)
|
||||
|
||||
@@ -215,7 +215,6 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@@ -239,27 +238,51 @@ class Qwen2_5_VLVisionAttention(nn.Module):
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
# Flash Attention 2: Use cu_seqlens for variable length attention
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Other implementations: Process each chunk separately
|
||||
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
splits = [
|
||||
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
||||
]
|
||||
|
||||
attn_outputs = [
|
||||
attention_interface(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
for q, k, v in zip(*splits)
|
||||
]
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
@@ -280,7 +303,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@@ -288,7 +310,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@@ -422,25 +443,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention masl only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -488,12 +490,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -159,7 +159,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@@ -167,7 +166,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@@ -289,25 +287,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention masl only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@@ -355,12 +334,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens_now,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -333,7 +333,6 @@ class VisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
@@ -357,27 +356,51 @@ class VisionAttention(nn.Module):
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
# Flash Attention 2: Use cu_seqlens for variable length attention
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
cu_seq_lens_q=cu_seqlens,
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Other implementations: Process each chunk separately
|
||||
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
splits = [
|
||||
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
|
||||
]
|
||||
|
||||
attn_outputs = [
|
||||
attention_interface(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attention_mask=None,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
for q, k, v in zip(*splits)
|
||||
]
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
@@ -400,7 +423,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@@ -408,7 +430,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@@ -721,25 +742,6 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention masl only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@@ -765,14 +767,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -419,7 +419,7 @@ class Glm4vIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=30)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
"\nWhat kind of dog is this?\n<think>Got it, let's look at the image. The animal in the picture has a stocky build, thick fur, and a face that's",
|
||||
"\nWhat kind of dog is this?\n<think>Got it, let's look at the image. The animal in the picture is not a dog; it's a cat. Specifically, it looks",
|
||||
"\nWhat kind of dog is this?\n<think>Got it, let's look at the image. Wait, the animals here are cats, not dogs. The question is about a dog, but"
|
||||
] # fmt: skip
|
||||
self.assertEqual(
|
||||
|
||||
Reference in New Issue
Block a user