diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 3bbfb57c82..57d50c181c 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -296,7 +296,6 @@ class Glm4vVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -320,27 +319,51 @@ class Glm4vVisionAttention(nn.Module): query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) @@ -361,7 +384,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -369,7 +391,6 @@ class Glm4vVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, - attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -467,25 +488,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb, pos_ids - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if self.config._attn_implementation == "flash_attention_2": - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: """ Args: @@ -515,14 +517,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens) for blk in self.blocks: hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, - attention_mask=attention_mask, ) hidden_states = self.post_layernorm(hidden_states) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 4b8ef1754b..3f8f593c20 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -603,25 +603,6 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb, pos_ids - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if self.config._attn_implementation == "flash_attention_2": - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: """ Args: @@ -651,14 +632,12 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens) for blk in self.blocks: hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, - attention_mask=attention_mask, ) hidden_states = self.post_layernorm(hidden_states) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 3a55c3a351..c76c6a447c 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -957,7 +957,6 @@ class Qwen2_5OmniVisionAttention(nn.Module): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -970,27 +969,51 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) @@ -1024,14 +1047,12 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, - attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -1191,25 +1212,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): return window_index, cu_window_seqlens - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if self.config._attn_implementation == "flash_attention_2": - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -1257,12 +1259,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): else: cu_seqlens_now = cu_window_seqlens - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, - attention_mask=attention_mask, **kwargs, ) hidden_states = self.merger(hidden_states) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 5ee0f347dd..c64e89a9ef 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1935,7 +1935,6 @@ class Qwen2_5OmniVisionAttention(nn.Module): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -1948,27 +1947,51 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) @@ -1985,14 +2008,12 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, - attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -2007,25 +2028,6 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): super().__init__(config, *inputs, **kwargs) self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if self.config._attn_implementation == "flash_attention_2": - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -2073,12 +2075,10 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): else: cu_seqlens_now = cu_window_seqlens - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, - attention_mask=attention_mask, **kwargs, ) hidden_states = self.merger(hidden_states) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 97d7791faa..0285d18cd2 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -215,7 +215,6 @@ class Qwen2_5_VLVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -239,27 +238,51 @@ class Qwen2_5_VLVisionAttention(nn.Module): query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) @@ -280,7 +303,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -288,7 +310,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, - attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -422,25 +443,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return window_index, cu_window_seqlens - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if self.config._attn_implementation == "flash_attention_2": - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -488,12 +490,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): else: cu_seqlens_now = cu_window_seqlens - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, - attention_mask=attention_mask, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index f18e0b3461..07f41356e8 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -159,7 +159,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -167,7 +166,6 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, - attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -289,25 +287,6 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return window_index, cu_window_seqlens - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if self.config._attn_implementation == "flash_attention_2": - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -355,12 +334,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): else: cu_seqlens_now = cu_window_seqlens - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, - attention_mask=attention_mask, **kwargs, ) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 2cd1a61b80..a8b2ebf1a9 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -333,7 +333,6 @@ class VisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -357,27 +356,51 @@ class VisionAttention(nn.Module): query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) @@ -400,7 +423,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -408,7 +430,6 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, - attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -721,25 +742,6 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if self.config._attn_implementation == "flash_attention_2": - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - @auto_docstring def forward( self, @@ -765,14 +767,12 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) for blk in self.blocks: hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, - attention_mask=attention_mask, **kwargs, ) diff --git a/tests/models/glm4v/test_modeling_glm4v.py b/tests/models/glm4v/test_modeling_glm4v.py index 39b66875c2..e5211e59f2 100644 --- a/tests/models/glm4v/test_modeling_glm4v.py +++ b/tests/models/glm4v/test_modeling_glm4v.py @@ -419,7 +419,7 @@ class Glm4vIntegrationTest(unittest.TestCase): output = model.generate(**inputs, max_new_tokens=30) EXPECTED_DECODED_TEXT = [ - "\nWhat kind of dog is this?\nGot it, let's look at the image. The animal in the picture has a stocky build, thick fur, and a face that's", + "\nWhat kind of dog is this?\nGot it, let's look at the image. The animal in the picture is not a dog; it's a cat. Specifically, it looks", "\nWhat kind of dog is this?\nGot it, let's look at the image. Wait, the animals here are cats, not dogs. The question is about a dog, but" ] # fmt: skip self.assertEqual(