From 9a4ce6477019358abc3ebd72d435da56f4c0ab7c Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Wed, 16 Apr 2025 18:15:22 +0200 Subject: [PATCH] :red_circle: Update CLIP vision attention to new attention interface (#37498) * update attention interface * fix test * propagate attention changes * revert weird changes * fix modular * what? * ruff is mocking me * ruff being ruff * simplify test suite + fix FA2 * fixup tests + propagate FA2 fixes * add Copied From where relevant * fix conflict between copies and modular * recover FA2 training for CLIP + handle quantization * don't ditch the warning * tiny import fix * code review (FA2 support, copied from) * fix style * modularity * wrong copies * future-proofing for TP * mlcd inherits from CLIP --- .../models/altclip/modeling_altclip.py | 139 ++++---- src/transformers/models/clip/modeling_clip.py | 300 ++++-------------- .../models/clipseg/modeling_clipseg.py | 132 ++++---- src/transformers/models/clvp/modeling_clvp.py | 1 - src/transformers/models/git/modeling_git.py | 134 ++++---- src/transformers/models/idefics/vision.py | 137 ++++---- .../models/kosmos2/modeling_kosmos2.py | 131 ++++---- src/transformers/models/mlcd/modeling_mlcd.py | 43 ++- .../models/x_clip/modeling_x_clip.py | 136 ++++---- tests/models/clip/test_modeling_clip.py | 242 ++------------ 10 files changed, 516 insertions(+), 879 deletions(-) diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 6e4c9e650d..9c8be7ea80 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -30,9 +30,15 @@ from ...modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndProjection, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, torch_int +from ...utils import ( + ModelOutput, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig @@ -721,7 +727,29 @@ class AltRobertaPooler(nn.Module): return pooled_output -# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->AltCLIP +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class AltCLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -738,15 +766,13 @@ class AltCLIPAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -756,74 +782,51 @@ class AltCLIPAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask else: - attn_weights_reshaped = None + self.is_causal = causal_attention_mask is not None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped + if not output_attentions: + attn_weights = None + return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->AltCLIP diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 2df825a8ad..4136aabf1d 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -15,19 +15,16 @@ """PyTorch CLIP model.""" from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -41,10 +38,6 @@ from ...utils import ( from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) # General docstring @@ -297,10 +290,34 @@ class CLIPTextEmbeddings(nn.Module): return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + output_attentions: bool = True, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + class CLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -313,15 +330,13 @@ class CLIPAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -331,242 +346,55 @@ class CLIPAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation == "flash_attention_2": + self.is_causal = causal_attention_mask is not None else: - attn_weights_reshaped = None + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped - - -class CLIPFlashAttention2(CLIPAttention): - """ - CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - output_attentions = False - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) - - dropout_rate = self.dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, attention_mask, - q_len, - dropout=dropout_rate, - is_causal=causal_attention_mask is not None, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + output_attentions=output_attentions, ) - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None - return attn_output, attn_weights -class CLIPSdpaAttention(CLIPAttention): - """ - SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `CLIPAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from CLIPAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - causal_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " - "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " - "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can " - 'be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - ) - - # CLIP text model uses both `causal_attention_mask` and `attention_mask` - if attention_mask is not None and causal_attention_mask is not None: - attn_mask = attention_mask + causal_attention_mask - elif causal_attention_mask is not None: - attn_mask = causal_attention_mask - else: - attn_mask = attention_mask - - bsz, tgt_len, embed_dim = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially. - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attn_mask, - dropout_p=self.dropout if self.training else 0.0, - scale=self.scale, - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None - - -CLIP_ATTENTION_CLASSES = { - "eager": CLIPAttention, - "sdpa": CLIPSdpaAttention, - "flash_attention_2": CLIPFlashAttention2, -} - - class CLIPMLP(nn.Module): def __init__(self, config): super().__init__() @@ -583,10 +411,10 @@ class CLIPMLP(nn.Module): class CLIPEncoderLayer(nn.Module): - def __init__(self, config: CLIPConfig): + def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = CLIPAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -952,7 +780,7 @@ class CLIPTextTransformer(nn.Module): # expand attention_mask if attention_mask is not None and not self._use_flash_attention_2: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs: BaseModelOutput = self.encoder( diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index a24847471f..d626914d72 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -17,7 +17,7 @@ import copy import math from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -26,7 +26,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, @@ -264,11 +264,33 @@ class CLIPSegTextEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->CLIPSeg +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class CLIPSegAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: Union[CLIPSegVisionConfig, CLIPSegTextConfig]): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -281,15 +303,13 @@ class CLIPSegAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -299,74 +319,52 @@ class CLIPSegAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask else: - attn_weights_reshaped = None + self.is_causal = causal_attention_mask is not None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None - return attn_output, attn_weights_reshaped + return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->CLIPSeg diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index afbab29383..320cf17126 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -303,7 +303,6 @@ class ClvpSelfAttention(nn.Module): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - # Copied from transformers.models.clip.modeling_clip.CLIPAttention._shape def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 7efdf2d45c..047b3a1fc4 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,7 +25,6 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...file_utils import ModelOutput from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( @@ -34,9 +33,10 @@ from ...modeling_outputs import ( BaseModelOutputWithPooling, CausalLMOutputWithPast, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( + ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging, @@ -698,7 +698,6 @@ class GitVisionEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.clip.modeling_clip.CLIPMLP class GitVisionMLP(nn.Module): def __init__(self, config): super().__init__() @@ -714,7 +713,29 @@ class GitVisionMLP(nn.Module): return hidden_states -# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->GitVision +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class GitVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -731,15 +752,13 @@ class GitVisionAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -749,74 +768,51 @@ class GitVisionAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask else: - attn_weights_reshaped = None + self.is_causal = causal_attention_mask is not None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped + if not output_attentions: + attn_weights = None + return attn_output, attn_weights # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 5e9f9b8ad7..5b2ef5ae3e 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -24,7 +24,11 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...utils import ModelOutput, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import ( + ModelOutput, + logging, +) from .configuration_idefics import IdeficsVisionConfig @@ -160,11 +164,33 @@ class IdeficsVisionEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class IdeficsVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: IdeficsVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -177,15 +203,13 @@ class IdeficsVisionAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -195,74 +219,51 @@ class IdeficsVisionAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask else: - attn_weights_reshaped = None + self.is_causal = causal_attention_mask is not None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped + if not output_attentions: + attn_weights = None + return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 23a1391fa1..e557293ee3 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -31,7 +31,7 @@ from ...modeling_outputs import ( BaseModelOutputWithPooling, CausalLMOutputWithCrossAttentions, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, @@ -466,7 +466,29 @@ class Kosmos2VisionEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Kosmos2Vision +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class Kosmos2VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -483,15 +505,13 @@ class Kosmos2VisionAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -501,74 +521,51 @@ class Kosmos2VisionAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask else: - attn_weights_reshaped = None + self.is_causal = causal_attention_mask is not None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped + if not output_attentions: + attn_weights = None + return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index 574537a7ad..851cd01a1a 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -172,25 +172,6 @@ class MLCDVisionEmbeddings(nn.Module): return embeddings -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -217,6 +198,25 @@ def eager_attention_forward( return attn_output, attn_weights +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def apply_rotary_pos_emb_vision( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -253,16 +253,13 @@ class MLCDAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.num_key_value_groups = config.num_key_value_groups - self.is_causal = False - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index f85b4636cd..a6a3655961 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -14,9 +14,9 @@ # limitations under the License. """PyTorch X-CLIP model.""" -from copy import copy +import copy from dataclasses import dataclass -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,7 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, @@ -223,7 +223,29 @@ class XCLIPTextEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->XCLIP +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class XCLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -240,15 +262,13 @@ class XCLIPAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -258,77 +278,54 @@ class XCLIPAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - # get query proj - query_states = self.q_proj(hidden_states) * self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - # apply the causal_attention_mask first - if causal_attention_mask is not None: - if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" - f" {causal_attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask` + if self.config._attn_implementation != "flash_attention_2": + if attention_mask is not None and causal_attention_mask is not None: + attention_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attention_mask = causal_attention_mask else: - attn_weights_reshaped = None + self.is_causal = causal_attention_mask is not None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None - return attn_output, attn_weights_reshaped + return attn_output, attn_weights -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->XCLIP class XCLIPMLP(nn.Module): def __init__(self, config): super().__init__() @@ -1330,8 +1327,7 @@ class XCLIPModel(XCLIPPreTrainedModel): self.prompts_visual_layernorm = nn.LayerNorm(self.vision_embed_dim, eps=config.vision_config.layer_norm_eps) self.prompts_visual_projection = nn.Parameter(torch.randn(self.vision_embed_dim, self.projection_dim)) - - mit_config = copy(vision_config) + mit_config = copy.copy(vision_config) mit_config.hidden_size = vision_config.mit_hidden_size mit_config.intermediate_size = vision_config.mit_intermediate_size mit_config.num_hidden_layers = vision_config.mit_num_hidden_layers diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index df5cc14fc4..739ae1f590 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -17,7 +17,6 @@ import inspect import os import tempfile import unittest -from typing import Optional import numpy as np import requests @@ -36,14 +35,12 @@ from transformers.testing_utils import ( ) from transformers.utils import ( is_torch_available, - is_torch_bf16_available_on_device, - is_torch_fp16_available_on_device, - is_torch_sdpa_available, is_vision_available, ) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, ModelTesterMixin, _config_zero_init, floats_tensor, @@ -67,11 +64,6 @@ if is_torch_available(): CLIPVisionModelWithProjection, ) - -if is_torch_sdpa_available(): - from torch.nn.attention import SDPBackend, sdpa_kernel - - if is_vision_available(): from PIL import Image @@ -170,6 +162,11 @@ class CLIPVisionModelTester: inputs_dict = {"pixel_values": pixel_values} return config, inputs_dict + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @require_torch_sdpa + def test_eager_matches_sdpa_inference(self, *args): + return getattr(ModelTesterMixin, self._testMethodName)(self) + class CLIPModelTesterMixin(ModelTesterMixin): """ @@ -178,6 +175,7 @@ class CLIPModelTesterMixin(ModelTesterMixin): different output logits, and are not supposed to be used or tested with padding_side="left". """ + @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -186,8 +184,8 @@ class CLIPModelTesterMixin(ModelTesterMixin): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - # Load the model with SDPA - model_sdpa = model_class.from_pretrained(tmpdirname) + # Load the model with SDPA (it is the default, but we explicit it for clarity) + model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa") model_sdpa = model_sdpa.eval().to(torch_device) # Load model with eager attention @@ -197,180 +195,17 @@ class CLIPModelTesterMixin(ModelTesterMixin): ) model_eager = model_eager.eval().to(torch_device) - # SigLip has one shared cls attr for all models, so we assign both submodels heer - vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager" - - # `None` as it is the requested one which will be assigned to each sub-config - # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) - if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"): - self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn) + if hasattr(model_sdpa, "vision_model"): + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa") self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + + if hasattr(model_sdpa, "text_model"): + self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa") self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - - def test_eager_matches_sdpa_inference( - self, - torch_dtype: str, - use_attention_mask_options: tuple[Optional[str], ...] = (None, "left", "right"), - logit_keys: tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), - ): - if not self.all_model_classes[0]._supports_sdpa: - self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") - - if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): - self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") - - if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): - self.skipTest( - f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" - ) - - # Convert to torch dtype - dtypes = { - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "float32": torch.float32, - } - torch_dtype = dtypes[torch_dtype] - - atols = { - torch.float32: 1e-5, - torch.bfloat16: 3e-2, - torch.float16: 5e-3, - } - rtols = { - torch.float32: 1e-4, - torch.bfloat16: 3e-2, - torch.float16: 5e-3, - } - - atol = atols[torch_dtype] - rtol = rtols[torch_dtype] - - def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): - return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" - - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - # Load the model with SDPA - model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) - model_sdpa = model_sdpa.eval().to(torch_device) - - # Load model with eager attention - model_eager = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch_dtype, - attn_implementation="eager", - ) - model_eager = model_eager.eval().to(torch_device) - - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, - # but it would be nicer to have an efficient way to use parameterized.expand - cases = [ - (use_mask, output_attentions, sdpa_backend, batch_size) - for use_mask in use_attention_mask_options - for output_attentions in [True, False] - for sdpa_backend in [ - [SDPBackend.MATH], - [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], - [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], - [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], - ] - for batch_size in [1, 5] - ] - fail_cases = [] - - for use_mask, output_attentions, sdpa_backend, batch_size in cases: - processed_inputs = inputs_dict.copy() - - # convert to torch_dtype - if "pixel_values" in processed_inputs: - processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype) - - # slice for different batch sizes - for key in ["pixel_values", "input_ids", "attention_mask"]: - if key in processed_inputs: - processed_inputs[key] = processed_inputs[key][:batch_size] - - # set attention mask with left padding - if not use_mask: - processed_inputs.pop("attention_mask", None) - elif use_mask == "left": - dummy_attention_mask = processed_inputs["attention_mask"] - dummy_attention_mask[:] = 1 - dummy_attention_mask[:, :1] = 0 - processed_inputs["attention_mask"] = dummy_attention_mask - elif use_mask == "right": - dummy_attention_mask = processed_inputs["attention_mask"] - dummy_attention_mask[:] = 1 - dummy_attention_mask[:, -1:] = 0 - processed_inputs["attention_mask"] = dummy_attention_mask - else: - raise ValueError(f"Invalid value for use_mask={use_mask}") - - processed_inputs["output_attentions"] = output_attentions - processed_inputs["output_hidden_states"] = True - - current_case = f"use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}" - - prepared_inputs = self._prepare_for_class(processed_inputs, model_class) - - with torch.no_grad(): - try: - with sdpa_kernel(sdpa_backend): - outputs_eager = model_eager(**prepared_inputs) - outputs_sdpa = model_sdpa(**prepared_inputs) - except Exception as e: - fail_cases.append(f"{current_case}: {e}") - continue - - keys = set(logit_keys) & set(outputs_eager.keys()) - self.assertTrue( - keys, f"Keys {logit_keys} not found in outputs. Available keys: {outputs_eager.keys()}" - ) - - for key in keys: - try: - eager_logits = outputs_eager[key] - sdpa_logits = outputs_sdpa[key] - except KeyError: - raise KeyError(f"Key {key} not found in outputs. Available keys: {outputs_eager.keys()}") - - if "hidden_state" in key and use_mask == "left": - eager_logits = eager_logits[:, 1:] - sdpa_logits = sdpa_logits[:, 1:] - elif "hidden_state" in key and use_mask == "right": - eager_logits = eager_logits[:, :-1] - sdpa_logits = sdpa_logits[:, :-1] - - is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol) - if not is_close: - fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol)) - - self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) - @require_torch class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase): @@ -458,16 +293,12 @@ class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) self.assertTrue(hasattr(model, "visual_projection")) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa - @slow @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("last_hidden_state", "pooler_output", "image_embeds"), - use_attention_mask_options=(None,), - ) + def test_eager_matches_sdpa_inference(self, *args): + # adding only flaky decorator here and call the parent test method + return getattr(ModelTesterMixin, self._testMethodName)(self) @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): @@ -632,16 +463,13 @@ class CLIPTextModelTest(CLIPModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) self.assertTrue(hasattr(model, "text_projection")) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa @slow @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("last_hidden_state", "pooler_output", "text_embeds"), - use_attention_mask_options=(None, "right"), # "left" is not supported for text model - ) + def test_eager_matches_sdpa_inference(self, *args): + # adding only flaky decorator here and call the parent test method + return getattr(ModelTesterMixin, self._testMethodName)(self) @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): @@ -860,16 +688,13 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase model = CLIPModel.from_pretrained(model_name) self.assertIsNotNone(model) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa @slow @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("logits_per_image", "logits_per_text"), - use_attention_mask_options=(None, "right"), # "left" is not supported for text model - ) + def test_eager_matches_sdpa_inference(self, *args): + # adding only flaky decorator here and call the parent test method + return getattr(ModelTesterMixin, self._testMethodName)(self) @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): @@ -1033,16 +858,13 @@ class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMi def test_initialization(self): pass - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa @slow @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("logits",), - use_attention_mask_options=(None,), - ) + def test_eager_matches_sdpa_inference(self, *args): + # adding only flaky decorator here and call the parent test method + return getattr(ModelTesterMixin, self._testMethodName)(self) @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): @@ -1062,7 +884,7 @@ class CLIPModelIntegrationTest(unittest.TestCase): @slow def test_inference(self): model_name = "openai/clip-vit-base-patch32" - model = CLIPModel.from_pretrained(model_name).to(torch_device) + model = CLIPModel.from_pretrained(model_name, attn_implementation="sdpa").to(torch_device) processor = CLIPProcessor.from_pretrained(model_name) image = prepare_img() @@ -1122,5 +944,5 @@ class CLIPModelIntegrationTest(unittest.TestCase): ).to(torch_device) torch.testing.assert_close( - outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4 + outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=6e-3, atol=4e-4 )