From 3249c5dc1560dace3c31cdbe4797b6c878ab47de Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 1 Apr 2025 14:37:25 +0100 Subject: [PATCH] Refactor attention for SigLIP based models (#36981) * Update Siglip attention implementation * Update tests for Siglip * Remove one level of indentation * Update test to be more specific * Fixup * Idefics2 * Idefics3 * Emu3 * SmolVLM * Phi4 (just init small update) * Idefics2 (test fix) * Update siglip2 tests * Update eager * trigger * Clean up * Transfer inputs to device in test * Fixing test * Fixing test * Revert contiguous * Remove unused is_flash_attn_2_available * Move flaky to specific models --- src/transformers/models/emu3/modeling_emu3.py | 72 ++-- src/transformers/models/emu3/modular_emu3.py | 9 +- .../models/idefics2/modeling_idefics2.py | 393 ++++-------------- .../models/idefics3/modeling_idefics3.py | 195 +++------ .../modeling_phi4_multimodal.py | 5 +- .../models/siglip/modeling_siglip.py | 252 +++-------- .../models/siglip2/modeling_siglip2.py | 251 +++-------- .../models/smolvlm/modeling_smolvlm.py | 193 +++------ .../models/idefics2/test_modeling_idefics2.py | 8 +- tests/models/siglip/test_modeling_siglip.py | 230 +--------- tests/models/siglip2/test_modeling_siglip2.py | 236 +---------- tests/test_modeling_common.py | 361 ++++++++-------- 12 files changed, 563 insertions(+), 1642 deletions(-) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 607ac4e6b5..3322ce28f9 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -600,7 +600,7 @@ class Emu3VQVAEResnetBlock(nn.Module): class Emu3VQVAEAttentionBlock(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: Emu3VQVAEConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -613,12 +613,16 @@ class Emu3VQVAEAttentionBlock(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + # for compatibility with the attention interface + self.num_key_value_groups = 1 + def forward( self, hidden_states: torch.Tensor, @@ -627,48 +631,43 @@ class Emu3VQVAEAttentionBlock(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - batch_size, q_len, _ = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = attn_weights + attention_mask + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights @@ -1005,6 +1004,9 @@ class Emu3VQVAE(PreTrainedModel): config_class = Emu3VQVAEConfig base_model_prefix = "emuvideovq" main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_flex_attn = True _no_split_modules = [ "Emu3VQVAETemporalResnetBlock", "Emu3VQVAEAttentionBlock", diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index d411ade67a..8af5ec700a 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -394,7 +394,11 @@ class Emu3VQVAEResnetBlock(nn.Module): class Emu3VQVAEAttentionBlock(SiglipAttention): - pass + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + # for compatibility with the attention interface + self.num_key_value_groups = 1 class Emu3VQVAEGroupNorm(nn.GroupNorm): @@ -730,6 +734,9 @@ class Emu3VQVAE(PreTrainedModel): config_class = Emu3VQVAEConfig base_model_prefix = "emuvideovq" main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_flex_attn = True _no_split_modules = [ "Emu3VQVAETemporalResnetBlock", "Emu3VQVAEAttentionBlock", diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 100c59b2af..5ee9ec4cc7 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -14,9 +14,8 @@ # limitations under the License. """PyTorch Idefics2 model.""" -import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -27,9 +26,8 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -41,10 +39,6 @@ from ..auto import AutoModel from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Idefics2Config" @@ -185,6 +179,33 @@ class Idefics2VisionEmbeddings(nn.Module): return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics2Vision class Idefics2VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -220,140 +241,38 @@ class Idefics2VisionAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - batch_size, q_len, _ = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class Idefics2VisionFlashAttention2(Idefics2VisionAttention): - """ - Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Idefics2VisionRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, attention_mask, - q_len, - dropout=dropout_rate, is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, ) - attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -362,12 +281,6 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): return attn_output, attn_weights -IDEFICS_VISION_ATTENTION_CLASSES = { - "eager": Idefics2VisionAttention, - "flash_attention_2": Idefics2VisionFlashAttention2, -} - - # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision class Idefics2VisionMLP(nn.Module): def __init__(self, config): @@ -437,7 +350,7 @@ class Idefics2EncoderLayer(nn.Module): def __init__(self, config: Idefics2VisionConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = Idefics2VisionAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Idefics2VisionMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -600,6 +513,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True def _init_weights(self, module): @@ -646,8 +560,10 @@ IDEFICS2_INPUTS_DOCSTRING = r""" IDEFICS2_START_DOCSTRING, ) class Idefics2VisionTransformer(Idefics2PreTrainedModel): - _supports_sdpa = False config_class = Idefics2VisionConfig + _supports_sdpa = True + _supports_flash_attention_2 = True + _supports_flex_attn = True def __init__(self, config: Idefics2VisionConfig): super().__init__(config) @@ -761,7 +677,7 @@ class Idefics2PerceiverAttention(nn.Module): def __init__(self, config, layer_idx: Optional[int] = None) -> None: """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" super().__init__() - + self.config = config self.layer_idx = None self.hidden_size = config.hidden_size self.num_heads = config.resampler_n_heads @@ -769,6 +685,7 @@ class Idefics2PerceiverAttention(nn.Module): self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.attention_dropout = config.attention_dropout + self.scaling = self.head_dim**-0.5 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -804,179 +721,41 @@ class Idefics2PerceiverAttention(nn.Module): hidden_states = torch.concat([context, latents], dim=-2) - query_states = self.q_proj(latents) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(latents) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + queries = queries.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + values = values.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + keys, values = past_key_value.update(keys, values, self.layer_idx) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2 -# TODO cyril: modular -class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): - """ - Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - # Ignore copy - def forward( - self, - latents: torch.Tensor, - context: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = latents.size() - kv_seq_len = q_len + context.size()[1] - - # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! - # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` - query_states = self.q_proj(latents) - key_states = self.k_proj(torch.cat([context, latents], dim=-2)) - value_states = self.v_proj(torch.cat([context, latents], dim=-2)) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: - slicing_tokens = kv_seq_len - self.config.sliding_window - - past_key = past_key_value[0] - past_value = past_key_value[1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1," - f" head_dim`), got {past_key.shape}" - ) - - past_key_value = (past_key, past_value) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=None, is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: @@ -985,12 +764,6 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): return attn_output, attn_weights, past_key_value -IDEFICS2_PERCEIVER_ATTENTION_CLASSES = { - "eager": Idefics2PerceiverAttention, - "flash_attention_2": Idefics2PerceiverFlashAttention2, -} - - class Idefics2PerceiverLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() @@ -1001,7 +774,7 @@ class Idefics2PerceiverLayer(nn.Module): self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) - self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.self_attn = Idefics2PerceiverAttention(config, layer_idx=layer_idx) self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self.mlp = Idefics2MLP( hidden_size=config.hidden_size, @@ -1084,8 +857,10 @@ IDEFICS2_INPUTS_DOCSTRING = r""" IDEFICS2_START_DOCSTRING, ) class Idefics2PerceiverResampler(Idefics2PreTrainedModel): - _supports_sdpa = False config_class = Idefics2PerceiverConfig + _supports_sdpa = True + _supports_flash_attention_2 = True + _supports_flex_attn = True def __init__(self, config) -> None: super().__init__(config) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 3ca17360c6..3821bd3e7a 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -15,7 +15,7 @@ """PyTorch Idefics3 model.""" from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -23,12 +23,11 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -39,10 +38,6 @@ from ..auto import AutoModel from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Idefics3Config" @@ -184,6 +179,30 @@ class Idefics3VisionEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision class Idefics3VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -219,141 +238,38 @@ class Idefics3VisionAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - batch_size, q_len, _ = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention2 with Idefics2->Idefics3 -class Idefics3VisionFlashAttention2(Idefics3VisionAttention): - """ - Idefics3Vision flash attention module. This module inherits from `Idefics3VisionAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Idefics3VisionRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, attention_mask, - q_len, - dropout=dropout_rate, is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, ) - attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -362,12 +278,6 @@ class Idefics3VisionFlashAttention2(Idefics3VisionAttention): return attn_output, attn_weights -IDEFICS_VISION_ATTENTION_CLASSES = { - "eager": Idefics3VisionAttention, - "flash_attention_2": Idefics3VisionFlashAttention2, -} - - # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision class Idefics3VisionMLP(nn.Module): def __init__(self, config): @@ -400,7 +310,7 @@ class Idefics3EncoderLayer(nn.Module): def __init__(self, config: Idefics3VisionConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = Idefics3VisionAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Idefics3VisionMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -620,6 +530,7 @@ class Idefics3PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights @@ -666,7 +577,9 @@ IDEFICS3_VISION_START_DOCSTRING = r""" ) class Idefics3VisionTransformer(Idefics3PreTrainedModel): config_class = Idefics3VisionConfig - _supports_sdpa = False + _supports_sdpa = True + _supports_flash_attention_2 = True + _supports_flex_attn = True def __init__(self, config: Idefics3VisionConfig): super().__init__(config) diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 5754459529..ea1a7384f5 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -150,12 +150,11 @@ class Phi4MultimodalVisionEncoderLayer(nn.Module): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = Phi4MultimodalVisionAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Phi4MultimodalVisionMLP(config) + self.self_attn = Phi4MultimodalVisionAttention(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) - # Ignore copy def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 0288b78381..f60feb2ea8 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -17,7 +17,7 @@ import math import warnings from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import numpy as np import torch @@ -28,9 +28,8 @@ from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, @@ -43,10 +42,6 @@ from ...utils import ( from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) # General docstring @@ -360,11 +355,33 @@ class SiglipTextEmbeddings(nn.Module): return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ - def __init__(self, config): + def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -377,6 +394,7 @@ class SiglipAttention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) @@ -391,130 +409,38 @@ class SiglipAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - batch_size, q_len, _ = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class SiglipFlashAttention2(SiglipAttention): - """ - SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - is_causal = False - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) - - dropout_rate = self.dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, attention_mask, - q_len, - dropout=dropout_rate, is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, ) - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -523,79 +449,6 @@ class SiglipFlashAttention2(SiglipAttention): return attn_output, attn_weights -class SiglipSdpaAttention(SiglipAttention): - """ - Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - is_causal = False - - # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if self.is_causal and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None - - -SIGLIP_ATTENTION_CLASSES = { - "eager": SiglipAttention, - "flash_attention_2": SiglipFlashAttention2, - "sdpa": SiglipSdpaAttention, -} - - # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): def __init__(self, config): @@ -613,15 +466,14 @@ class SiglipMLP(nn.Module): class SiglipEncoderLayer(nn.Module): - def __init__(self, config: SiglipConfig): + def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) + self.self_attn = SiglipAttention(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) - # Ignore copy def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index 663d9f3dd0..15d9cffe7d 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -21,7 +21,7 @@ import math import warnings from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import numpy as np import torch @@ -32,9 +32,8 @@ from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, @@ -46,10 +45,6 @@ from ...utils import ( from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) # General docstring @@ -252,10 +247,33 @@ class Siglip2VisionEmbeddings(nn.Module): return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Siglip2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -268,6 +286,7 @@ class Siglip2Attention(nn.Module): ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout + self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) @@ -282,130 +301,38 @@ class Siglip2Attention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - batch_size, q_len, _ = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class Siglip2FlashAttention2(Siglip2Attention): - """ - Siglip2Attention flash attention module. This module inherits from `Siglip2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - is_causal = False - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) - - dropout_rate = self.dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, attention_mask, - q_len, - dropout=dropout_rate, is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, ) - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -414,72 +341,6 @@ class Siglip2FlashAttention2(Siglip2Attention): return attn_output, attn_weights -class Siglip2SdpaAttention(Siglip2Attention): - """ - Siglip2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Siglip2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - is_causal = False - - # Adapted from Siglip2Attention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Siglip2Model is using Siglip2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if self.is_causal and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None - - class Siglip2MLP(nn.Module): def __init__(self, config): super().__init__() @@ -495,23 +356,15 @@ class Siglip2MLP(nn.Module): return hidden_states -SIGLIP2_ATTENTION_CLASSES = { - "eager": Siglip2Attention, - "flash_attention_2": Siglip2FlashAttention2, - "sdpa": Siglip2SdpaAttention, -} - - class Siglip2EncoderLayer(nn.Module): - def __init__(self, config: Siglip2Config): + def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = SIGLIP2_ATTENTION_CLASSES[config._attn_implementation](config=config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Siglip2MLP(config) + self.self_attn = Siglip2Attention(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) - # Ignore copy def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 466cc9e21a..f9cad6fa20 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -20,19 +20,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -43,10 +42,6 @@ from ..auto import AutoModel from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "SmolVLMConfig" @@ -81,6 +76,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True def _init_weights(self, module): @@ -161,6 +157,29 @@ class SmolVLMVisionEmbeddings(nn.Module): return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class SmolVLMVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -194,140 +213,38 @@ class SmolVLMVisionAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" - batch_size, q_len, _ = hidden_states.size() + batch_size, seq_length, embed_dim = hidden_states.shape - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class SmolVLMVisionFlashAttention2(SmolVLMVisionAttention): - """ - SmolVLMVision flash attention module. This module inherits from `SmolVLMVisionAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (SmolVLMVisionRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, attention_mask, - q_len, - dropout=dropout_rate, is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, ) - attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -351,17 +268,11 @@ class SmolVLMVisionMLP(nn.Module): return hidden_states -IDEFICS_VISION_ATTENTION_CLASSES = { - "eager": SmolVLMVisionAttention, - "flash_attention_2": SmolVLMVisionFlashAttention2, -} - - class SmolVLMEncoderLayer(nn.Module): def __init__(self, config: SmolVLMVisionConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = SmolVLMVisionAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SmolVLMVisionMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -516,7 +427,9 @@ SMOLVLM_VISION_START_DOCSTRING = r""" ) class SmolVLMVisionTransformer(SmolVLMPreTrainedModel): config_class = SmolVLMVisionConfig - _supports_sdpa = False + _supports_sdpa = True + _supports_flash_attention_2 = True + _supports_flex_attn = True def __init__(self, config: SmolVLMVisionConfig): super().__init__(config) diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 56df4bb801..98a0b2d50b 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -344,17 +344,15 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase): model_sdpa = model_class.from_pretrained(tmpdirname) model_sdpa = model_sdpa.eval().to(torch_device) - vision_attn = None if model.vision_model._supports_sdpa else "eager" - perceiver_attn = None if model.connector.perceiver_resampler._supports_sdpa else "eager" self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == perceiver_attn) + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "sdpa") model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") - self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "eager") + self.assertTrue(model_eager.connector.perceiver_resampler.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 93268bbc8e..10d57a5aeb 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -18,7 +18,6 @@ import inspect import os import tempfile import unittest -from typing import Tuple import numpy as np import requests @@ -27,30 +26,28 @@ from pytest import mark from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig from transformers.testing_utils import ( + is_flaky, require_flash_attn, require_torch, require_torch_gpu, - require_torch_sdpa, require_vision, slow, torch_device, ) from transformers.utils import ( is_torch_available, - is_torch_bf16_available_on_device, - is_torch_fp16_available_on_device, - is_torch_sdpa_available, is_vision_available, ) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, - is_flaky, random_attention_mask, + require_torch_sdpa, ) from ...test_pipeline_mixin import PipelineTesterMixin @@ -61,9 +58,6 @@ if is_torch_available(): from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel -if is_torch_sdpa_available(): - from torch.nn.attention import SDPBackend, sdpa_kernel - if is_vision_available(): from PIL import Image @@ -71,6 +65,7 @@ if is_vision_available(): class SiglipModelTesterMixin(ModelTesterMixin): + @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -81,171 +76,24 @@ class SiglipModelTesterMixin(ModelTesterMixin): # Load the model with SDPA model_sdpa = model_class.from_pretrained(tmpdirname) - model_sdpa = model_sdpa.eval().to(torch_device) # Load model with eager attention model_eager = model_class.from_pretrained( tmpdirname, attn_implementation="eager", ) - model_eager = model_eager.eval().to(torch_device) - # SigLip has one shared cls attr for all models, so we assign both submodels heer - vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager" - - if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"): - self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn) + if hasattr(model_sdpa, "vision_model"): + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa") self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + + if hasattr(model_sdpa, "text_model"): + self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa") self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - - def test_eager_matches_sdpa_inference( - self, - torch_dtype: str, - use_attention_mask_options: Tuple[bool, ...] = (True, False), - logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), - ): - if not self.all_model_classes[0]._supports_sdpa: - self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") - - if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): - self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") - - if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): - self.skipTest( - f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" - ) - - # Convert to torch dtype - dtypes = { - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "float32": torch.float32, - } - torch_dtype = dtypes[torch_dtype] - - atols = { - torch.float32: 1e-5, - torch.bfloat16: 3e-2, - torch.float16: 5e-3, - } - rtols = { - torch.float32: 1e-4, - torch.bfloat16: 3e-2, - torch.float16: 5e-3, - } - - atol = atols[torch_dtype] - rtol = rtols[torch_dtype] - - def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): - return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" - - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - # Load the model with SDPA - model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) - model_sdpa = model_sdpa.eval().to(torch_device) - - # Load model with eager attention - model_eager = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch_dtype, - attn_implementation="eager", - ) - model_eager = model_eager.eval().to(torch_device) - - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, - # but it would be nicer to have an efficient way to use parameterized.expand - cases = [ - (use_mask, output_attentions, sdpa_backend, batch_size) - for use_mask in use_attention_mask_options - for output_attentions in [True, False] - for sdpa_backend in [ - SDPBackend.MATH, - [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], - [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], - [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], - ] - for batch_size in [1, 5] - ] - fail_cases = [] - - for use_mask, output_attentions, sdpa_backend, batch_size in cases: - processed_inputs = inputs_dict.copy() - - # convert to torch_dtype - if "pixel_values" in processed_inputs: - processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype) - - # slice for different batch sizes - for key in ["pixel_values", "input_ids", "attention_mask"]: - if key in processed_inputs: - processed_inputs[key] = processed_inputs[key][:batch_size] - - # set attention mask with left padding - if not use_mask: - processed_inputs.pop("attention_mask", None) - else: - dummy_attention_mask = processed_inputs["attention_mask"] - dummy_attention_mask[:] = 1 - dummy_attention_mask[:, :1] = 0 - processed_inputs["attention_mask"] = dummy_attention_mask - - processed_inputs["output_attentions"] = output_attentions - processed_inputs["output_hidden_states"] = True - - current_case = ( - f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}" - ) - - prepared_inputs = self._prepare_for_class(processed_inputs, model_class) - - with torch.no_grad(): - try: - with sdpa_kernel(sdpa_backend): - outputs_eager = model_eager(**prepared_inputs) - outputs_sdpa = model_sdpa(**prepared_inputs) - except Exception as e: - fail_cases.append(f"{current_case}: {e}") - continue - - for key in logit_keys: - eager_logits = outputs_eager[key] - sdpa_logits = outputs_sdpa[key] - - if use_mask: - eager_logits = eager_logits[:, 1:] - sdpa_logits = sdpa_logits[:, 1:] - - is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol) - if not is_close: - fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol)) - - self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) - class SiglipVisionModelTester: def __init__( @@ -409,20 +257,12 @@ class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase): model = SiglipVisionModel.from_pretrained(model_name) self.assertIsNotNone(model) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa - @slow @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("pooler_output", "last_hidden_state"), - use_attention_mask_options=(False,), - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() + def test_eager_matches_sdpa_inference(self, *args): + # adding only flaky decorator here and call the parent test method + return getattr(ModelTesterMixin, self._testMethodName)(self) class SiglipTextModelTester: @@ -565,21 +405,6 @@ class SiglipTextModelTest(SiglipModelTesterMixin, unittest.TestCase): model = SiglipTextModel.from_pretrained(model_name) self.assertIsNotNone(model) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) - @require_torch_sdpa - @slow - @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("pooler_output", "last_hidden_state"), - use_attention_mask_options=(False, True), - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() - class SiglipModelTester: def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): @@ -634,6 +459,7 @@ class SiglipModelTester: @require_torch class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + additional_model_inputs = ["pixel_values"] all_model_classes = (SiglipModel,) if is_torch_available() else () pipeline_model_mapping = {"feature-extraction": SiglipModel} if is_torch_available() else {} fx_compatible = False @@ -862,21 +688,6 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("SigLIP does not support right padding") - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) - @require_torch_sdpa - @slow - @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), - use_attention_mask_options=(False, True), - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() - class SiglipForImageClassificationModelTester(SiglipModelTester): def __init__(self, parent): @@ -943,19 +754,6 @@ class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTest def test_initialization(self): pass - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) - @require_torch_sdpa - @slow - @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,) - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() - # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/siglip2/test_modeling_siglip2.py b/tests/models/siglip2/test_modeling_siglip2.py index f5959edb5f..885f57751b 100644 --- a/tests/models/siglip2/test_modeling_siglip2.py +++ b/tests/models/siglip2/test_modeling_siglip2.py @@ -17,7 +17,6 @@ import inspect import tempfile import unittest -from typing import Tuple import numpy as np from parameterized import parameterized @@ -25,29 +24,27 @@ from pytest import mark from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig from transformers.testing_utils import ( + is_flaky, require_flash_attn, require_torch, require_torch_gpu, - require_torch_sdpa, require_vision, slow, torch_device, ) from transformers.utils import ( is_torch_available, - is_torch_bf16_available_on_device, - is_torch_fp16_available_on_device, - is_torch_sdpa_available, is_vision_available, ) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, ModelTesterMixin, floats_tensor, ids_tensor, - is_flaky, random_attention_mask, + require_torch_sdpa, ) from ...test_pipeline_mixin import PipelineTesterMixin @@ -58,9 +55,6 @@ if is_torch_available(): from transformers import Siglip2ForImageClassification, Siglip2Model, Siglip2TextModel, Siglip2VisionModel -if is_torch_sdpa_available(): - from torch.nn.attention import SDPBackend, sdpa_kernel - if is_vision_available(): from PIL import Image, ImageDraw @@ -68,6 +62,7 @@ if is_vision_available(): class Siglip2ModelTesterMixin(ModelTesterMixin): + @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -78,171 +73,24 @@ class Siglip2ModelTesterMixin(ModelTesterMixin): # Load the model with SDPA model_sdpa = model_class.from_pretrained(tmpdirname) - model_sdpa = model_sdpa.eval().to(torch_device) # Load model with eager attention model_eager = model_class.from_pretrained( tmpdirname, attn_implementation="eager", ) - model_eager = model_eager.eval().to(torch_device) - # SigLip has one shared cls attr for all models, so we assign both submodels heer - vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager" - - if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"): - self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) - self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn) + if hasattr(model_sdpa, "vision_model"): + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa") self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + + if hasattr(model_sdpa, "text_model"): + self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa") self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - - def test_eager_matches_sdpa_inference( - self, - torch_dtype: str, - use_attention_mask_options: Tuple[bool, ...] = (True, False), - logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), - ): - if not self.all_model_classes[0]._supports_sdpa: - self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") - - if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): - self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") - - if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): - self.skipTest( - f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" - ) - - # Convert to torch dtype - dtypes = { - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "float32": torch.float32, - } - torch_dtype = dtypes[torch_dtype] - - atols = { - torch.float32: 1e-5, - torch.bfloat16: 3e-2, - torch.float16: 5e-3, - } - rtols = { - torch.float32: 1e-4, - torch.bfloat16: 3e-2, - torch.float16: 5e-3, - } - - atol = atols[torch_dtype] - rtol = rtols[torch_dtype] - - def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): - return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" - - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - # Load the model with SDPA - model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) - model_sdpa = model_sdpa.eval().to(torch_device) - - # Load model with eager attention - model_eager = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch_dtype, - attn_implementation="eager", - ) - model_eager = model_eager.eval().to(torch_device) - - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, - # but it would be nicer to have an efficient way to use parameterized.expand - cases = [ - (use_mask, output_attentions, sdpa_backend, batch_size) - for use_mask in use_attention_mask_options - for output_attentions in [True, False] - for sdpa_backend in [ - SDPBackend.MATH, - [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], - [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], - [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], - ] - for batch_size in [1, 5] - ] - fail_cases = [] - - for use_mask, output_attentions, sdpa_backend, batch_size in cases: - processed_inputs = inputs_dict.copy() - - # convert to torch_dtype - if "pixel_values" in processed_inputs: - processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype) - - # slice for different batch sizes - for key in processed_inputs.keys(): - if isinstance(processed_inputs[key], (torch.Tensor, list, tuple)): - processed_inputs[key] = processed_inputs[key][:batch_size] - - # set attention mask with left padding - if not use_mask: - processed_inputs.pop("attention_mask", None) - else: - dummy_attention_mask = processed_inputs["attention_mask"] - dummy_attention_mask[:] = 1 - dummy_attention_mask[:, :1] = 0 - processed_inputs["attention_mask"] = dummy_attention_mask - - processed_inputs["output_attentions"] = output_attentions - processed_inputs["output_hidden_states"] = True - - current_case = ( - f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}" - ) - - prepared_inputs = self._prepare_for_class(processed_inputs, model_class) - - with torch.no_grad(): - try: - with sdpa_kernel(sdpa_backend): - outputs_eager = model_eager(**prepared_inputs) - outputs_sdpa = model_sdpa(**prepared_inputs) - except Exception as e: - fail_cases.append(f"{current_case}: {e}") - continue - - for key in logit_keys: - eager_logits = outputs_eager[key] - sdpa_logits = outputs_sdpa[key] - - if use_mask: - eager_logits = eager_logits[:, 1:] - sdpa_logits = sdpa_logits[:, 1:] - - is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol) - if not is_close: - fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol)) - - self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) - @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -422,6 +270,7 @@ class Siglip2VisionModelTest(Siglip2ModelTesterMixin, unittest.TestCase): """ all_model_classes = (Siglip2VisionModel,) if is_torch_available() else () + additional_model_inputs = ["pixel_attention_mask", "spatial_shapes"] fx_compatible = False test_pruning = False test_resize_embeddings = False @@ -497,20 +346,12 @@ class Siglip2VisionModelTest(Siglip2ModelTesterMixin, unittest.TestCase): model = Siglip2VisionModel.from_pretrained(model_name) self.assertIsNotNone(model) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa - @slow @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("pooler_output", "last_hidden_state"), - use_attention_mask_options=(False,), - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() + def test_eager_matches_sdpa_inference(self, *args): + # adding only flaky decorator here and call the parent test method + return getattr(ModelTesterMixin, self._testMethodName)(self) class Siglip2TextModelTester: @@ -648,21 +489,6 @@ class Siglip2TextModelTest(Siglip2ModelTesterMixin, unittest.TestCase): model = Siglip2TextModel.from_pretrained(model_name) self.assertIsNotNone(model) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) - @require_torch_sdpa - @slow - @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("pooler_output", "last_hidden_state"), - use_attention_mask_options=(False, True), - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() - class Siglip2ModelTester: def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): @@ -725,6 +551,11 @@ class Siglip2ModelTester: class Siglip2ModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Siglip2Model,) if is_torch_available() else () pipeline_model_mapping = {"feature-extraction": Siglip2Model} if is_torch_available() else {} + additional_model_inputs = [ + "pixel_values", + "pixel_attention_mask", + "spatial_shapes", + ] fx_compatible = False test_head_masking = False test_pruning = False @@ -796,21 +627,6 @@ class Siglip2ModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.Te def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Siglip2 does not support right padding") - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) - @require_torch_sdpa - @slow - @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, - logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), - use_attention_mask_options=(False, True), - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() - class Siglip2ForImageClassificationModelTester(Siglip2ModelTester): def __init__(self, parent): @@ -841,6 +657,7 @@ class Siglip2ForImageClassificationModelTester(Siglip2ModelTester): class Siglip2ForImageClassificationModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Siglip2ForImageClassification,) if is_torch_available() else () pipeline_model_mapping = {"image-classification": Siglip2ForImageClassification} if is_torch_available() else {} + additional_model_inputs = ["pixel_values", "pixel_attention_mask", "spatial_shapes"] fx_compatible = False test_head_masking = False test_pruning = False @@ -881,19 +698,6 @@ class Siglip2ForImageClassificationModelTest(Siglip2ModelTesterMixin, PipelineTe def test_initialization(self): pass - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) - @require_torch_sdpa - @slow - @is_flaky() - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - super().test_eager_matches_sdpa_inference( - torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,) - ) - - @require_torch_sdpa - def test_sdpa_can_dispatch_composite_models(self): - super().test_sdpa_can_dispatch_composite_models() - # Draw a circle on an images with different aspect ratios def prepare_images(): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a8e6f0a434..ce8921b333 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3457,7 +3457,7 @@ class ModelTesterMixin: ): # TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like # models have a custom mixin, which we detect to skip this test. - if not any(".ModelTesterMixin" in str(base) for base in self.__class__.__bases__): + if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__): self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`") if not self.has_attentions: @@ -3549,206 +3549,213 @@ class ModelTesterMixin: model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype) - set_model_for_less_flaky_test(model_eager) - set_model_for_less_flaky_test(model_sdpa) + set_model_for_less_flaky_test(model_eager) + set_model_for_less_flaky_test(model_sdpa) - can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters - if not (self.has_attentions and can_output_attn) and output_attentions: - self.skipTest(reason="Model does not support output_attentions") + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + self.skipTest(reason="Model does not support output_attentions") - # TODO: if we can also check with `batch_size=1` without being flaky? - for batch_size in [7]: - # musicgen decoder models; TODO: find better abstraction - if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"): + # TODO: if we can also check with `batch_size=1` without being flaky? + for batch_size in [7]: + # musicgen decoder models; TODO: find better abstraction + if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"): + input_data_batch_size = batch_size * self.model_tester.num_codebooks + else: + input_data_batch_size = batch_size + + processed_inputs = {} + processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name] + + for key in getattr(self, "additional_model_inputs", []): + processed_inputs[key] = inputs_dict[key] + + for key, value in processed_inputs.items(): + if torch.is_floating_point(value): + value = value.to(torch_dtype) + + # extend value to have at least `input_data_batch_size` elements + if value.shape[0] < input_data_batch_size: + size = (input_data_batch_size - value.shape[0], *value.shape[1:]) + if torch.is_floating_point(value): + extension = torch.rand(size=size, dtype=value.dtype, device=torch_device) + else: + extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device) + value = torch.cat((value, extension), dim=0).to(torch_device) + + processed_inputs[key] = value[:input_data_batch_size] + + if not use_attention_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get( + "decoder_input_ids", processed_inputs[model.main_input_name] + ).shape[-1] + else: + seqlen = processed_inputs[model.main_input_name].shape[-1] + dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + + # extend dummy_attention_mask to have at least `batch_size` elements + if dummy_attention_mask.shape[0] < batch_size: + size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:]) + extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + + dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :2] = 0 + dummy_attention_mask[-1, 2:] = 1 + elif padding_side == "right": + dummy_attention_mask[-1, -2:] = 0 + dummy_attention_mask[-1, :-2] = 1 + + if is_encoder_decoder: + # musicgen encoder-decoder models; TODO: find better abstraction + if hasattr(self.model_tester, "num_codebooks"): input_data_batch_size = batch_size * self.model_tester.num_codebooks else: input_data_batch_size = batch_size - dummy_input = inputs_dict[model.main_input_name] + decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]) + decoder_input_ids = decoder_input_ids[:input_data_batch_size] + if decoder_input_ids.shape[0] != input_data_batch_size: + extension = torch.ones( + input_data_batch_size - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) - if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: - dummy_input = dummy_input.to(torch_dtype) - - dummy_input = dummy_input[:input_data_batch_size] - if dummy_input.shape[0] != input_data_batch_size: - if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: - extension = torch.rand( - input_data_batch_size - dummy_input.shape[0], - *dummy_input.shape[1:], - dtype=torch_dtype, - device=torch_device, - ) - dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) - else: - extension = torch.randint( - high=5, - size=(input_data_batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), - dtype=dummy_input.dtype, - device=torch_device, - ) - dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) - - if not use_attention_mask: - dummy_attention_mask = None - else: - dummy_attention_mask = inputs_dict.get("attention_mask", None) - if dummy_attention_mask is None: - if is_encoder_decoder: - seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] - else: - seqlen = dummy_input.shape[-1] - dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) - - dummy_attention_mask = dummy_attention_mask[:batch_size] - if dummy_attention_mask.shape[0] != batch_size: - extension = torch.ones( - batch_size - dummy_attention_mask.shape[0], - *dummy_attention_mask.shape[1:], - dtype=dummy_attention_mask.dtype, - device=torch_device, - ) - dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) - dummy_attention_mask = dummy_attention_mask.to(torch_device) - - dummy_attention_mask[:] = 1 - if padding_side == "left": - dummy_attention_mask[-1, :2] = 0 - dummy_attention_mask[-1, 2:] = 1 - elif padding_side == "right": - dummy_attention_mask[-1, -2:] = 0 - dummy_attention_mask[-1, :-2] = 1 - - if is_encoder_decoder: - # musicgen encoder-decoder models; TODO: find better abstraction - if hasattr(self.model_tester, "num_codebooks"): - input_data_batch_size = batch_size * self.model_tester.num_codebooks - else: - input_data_batch_size = batch_size - - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:input_data_batch_size] - if decoder_input_ids.shape[0] != input_data_batch_size: - extension = torch.ones( - input_data_batch_size - decoder_input_ids.shape[0], - *decoder_input_ids.shape[1:], - dtype=decoder_input_ids.dtype, - device=torch_device, - ) - decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) - decoder_input_ids = decoder_input_ids.to(torch_device) - - # TODO: never an `attention_mask` arg here? - processed_inputs = { - model.main_input_name: dummy_input, + # TODO: never an `attention_mask` arg here? + processed_inputs.update( + { "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": dummy_attention_mask, "output_hidden_states": True, } - else: - processed_inputs = { - model.main_input_name: dummy_input, + ) + else: + processed_inputs.update( + { "output_hidden_states": True, } + ) - # Otherwise fails for e.g. WhisperEncoderModel - if "attention_mask" in inspect.signature(model_eager.forward).parameters: - processed_inputs["attention_mask"] = dummy_attention_mask + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + processed_inputs["attention_mask"] = dummy_attention_mask - if ( - self.has_attentions - and "output_attentions" in inspect.signature(model_sdpa.forward).parameters - ): - processed_inputs["output_attentions"] = output_attentions - if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: - dummy_mask = torch.ones((self.model_tester.num_masks,)) + if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters: + processed_inputs["output_attentions"] = output_attentions + if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: + dummy_mask = torch.ones((self.model_tester.num_masks,)) - # In case of additional token (like class) we define a custom `mask_length` - if hasattr(self.model_tester, "mask_length"): - mask_length = self.model_tester.mask_length - dummy_mask.size(0) - else: - mask_length = self.model_tester.seq_length - dummy_mask.size(0) - dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) - dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() - processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) - - if "noise" in inspect.signature(model_eager.forward).parameters: - np.random.seed(2) - num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2) - noise = np.random.uniform(size=(batch_size, num_patches)) - processed_inputs["noise"] = torch.from_numpy(noise) - - # TODO: test gradients as well (& for FA2 as well!) - with torch.no_grad(): - with sdpa_kernel( - enable_flash=enable_kernels, - enable_math=True, - enable_mem_efficient=enable_kernels, - ): - prepared_inputs = self._prepare_for_class(processed_inputs, model_class) - outputs_eager = model_eager(**prepared_inputs) - outputs_sdpa = model_sdpa(**prepared_inputs) - - # TODO: rename logits -> hidden_states - if hasattr(outputs_eager, "vision_hidden_states"): - logits_eager = outputs_eager.vision_hidden_states[-1] - logits_sdpa = outputs_sdpa.vision_hidden_states[-1] - elif hasattr(outputs_eager, "audio_values"): - logits_eager = outputs_eager.audio_values - logits_sdpa = outputs_sdpa.audio_values + # In case of additional token (like class) we define a custom `mask_length` + if hasattr(self.model_tester, "mask_length"): + mask_length = self.model_tester.mask_length - dummy_mask.size(0) else: - logits_eager = ( - outputs_eager.decoder_hidden_states[-1] - if hasattr(outputs_eager, "decoder_hidden_states") - else outputs_eager.hidden_states[-1] - ) - logits_sdpa = ( - outputs_sdpa.decoder_hidden_states[-1] - if hasattr(outputs_sdpa, "decoder_hidden_states") - else outputs_sdpa.hidden_states[-1] - ) + mask_length = self.model_tester.seq_length - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) - if torch_device in ["cpu", "cuda"]: - atol = atols[torch_device, enable_kernels, torch_dtype] - rtol = rtols[torch_device, enable_kernels, torch_dtype] - elif torch_device == "xpu": - # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH - # which is implemented on PyTorch level using aten operators and is - # device agnostic with respect to implementation of each aten operator. - atol = atols["cuda", False, torch_dtype] - rtol = rtols["cuda", False, torch_dtype] - else: - atol = 1e-7 - rtol = 1e-4 + if "noise" in inspect.signature(model_eager.forward).parameters: + np.random.seed(2) + num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2) + noise = np.random.uniform(size=(batch_size, num_patches)) + processed_inputs["noise"] = torch.from_numpy(noise) - # Masked tokens output slightly deviates - we don't mind that. - if use_attention_mask: - _logits_sdpa = torch.zeros_like(input=logits_sdpa) - _logits_eager = torch.zeros_like(input=logits_eager) + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with sdpa_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + prepared_inputs = { + k: v.to(torch_device) if isinstance(v, torch.Tensor) else v + for k, v in prepared_inputs.items() + } + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) - _logits_sdpa[:-1] = logits_sdpa[:-1] - _logits_eager[:-1] = logits_eager[:-1] + if "logits_per_text" in outputs_eager: + key = "logits_per_text" + elif "vision_hidden_states" in outputs_eager: + key = "vision_hidden_states" + elif "audio_values" in outputs_eager: + key = "audio_values" + elif "decoder_hidden_states" in outputs_eager: + key = "decoder_hidden_states" + elif "logits" in outputs_eager and "Classification" in model_class.__name__: + key = "logits" + else: + key = "hidden_states" - if padding_side == "left": - _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] - _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] + # TODO: rename logits -> hidden_states + logits_eager = outputs_eager[key] + logits_sdpa = outputs_sdpa[key] - elif padding_side == "right": - _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] - _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] + if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]: + logits_eager = logits_eager[-1] + logits_sdpa = logits_sdpa[-1] - logits_sdpa = _logits_sdpa - logits_eager = _logits_eager + if key == "logits_per_text": + nan_mask = torch.isnan(logits_eager) + logits_eager[nan_mask] = 0 + logits_sdpa[nan_mask] = 0 - results = [ - torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) - for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) - ] - # If 80% batch elements have matched results, it's fine - if np.mean(results) < 0.8: - mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean() - raise ValueError( - f"mean relative difference: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = " - f"{rtol}" - ) + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + elif torch_device == "xpu": + # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH + # which is implemented on PyTorch level using aten operators and is + # device agnostic with respect to implementation of each aten operator. + atol = atols["cuda", False, torch_dtype] + rtol = rtols["cuda", False, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. + if use_attention_mask: + _logits_sdpa = torch.zeros_like(input=logits_sdpa) + _logits_eager = torch.zeros_like(input=logits_eager) + + _logits_sdpa[:-1] = logits_sdpa[:-1] + _logits_eager[:-1] = logits_eager[:-1] + + if padding_side == "left": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] + _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] + + elif padding_side == "right": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] + _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] + + logits_sdpa = _logits_sdpa + logits_eager = _logits_eager + + results = [ + torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) + for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) + ] + # If 80% batch elements have matched results, it's fine + if np.mean(results) < 0.8: + mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean() + raise ValueError( + f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = " + f"{rtol}" + ) @require_torch_sdpa @require_torch_gpu