Refactor attention for SigLIP based models (#36981)

* Update Siglip attention implementation

* Update tests for Siglip

* Remove one level of indentation

* Update test to be more specific

* Fixup

* Idefics2

* Idefics3

* Emu3

* SmolVLM

* Phi4 (just init small update)

* Idefics2 (test fix)

* Update siglip2 tests

* Update eager

* trigger

* Clean up

* Transfer inputs to device in test

* Fixing test

* Fixing test

* Revert contiguous

* Remove unused is_flash_attn_2_available

* Move flaky to specific models
This commit is contained in:
Pavel Iakubovskii
2025-04-01 14:37:25 +01:00
committed by GitHub
parent 24e311f42b
commit 3249c5dc15
12 changed files with 563 additions and 1642 deletions

View File

@@ -600,7 +600,7 @@ class Emu3VQVAEResnetBlock(nn.Module):
class Emu3VQVAEAttentionBlock(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
def __init__(self, config: Emu3VQVAEConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -613,12 +613,16 @@ class Emu3VQVAEAttentionBlock(nn.Module):
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
# for compatibility with the attention interface
self.num_key_value_groups = 1
def forward(
self,
hidden_states: torch.Tensor,
@@ -627,48 +631,43 @@ class Emu3VQVAEAttentionBlock(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
batch_size, seq_length, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
@@ -1005,6 +1004,9 @@ class Emu3VQVAE(PreTrainedModel):
config_class = Emu3VQVAEConfig
base_model_prefix = "emuvideovq"
main_input_name = "pixel_values"
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_flex_attn = True
_no_split_modules = [
"Emu3VQVAETemporalResnetBlock",
"Emu3VQVAEAttentionBlock",

View File

@@ -394,7 +394,11 @@ class Emu3VQVAEResnetBlock(nn.Module):
class Emu3VQVAEAttentionBlock(SiglipAttention):
pass
def __init__(self, config: Emu3VQVAEConfig):
super().__init__(config)
# for compatibility with the attention interface
self.num_key_value_groups = 1
class Emu3VQVAEGroupNorm(nn.GroupNorm):
@@ -730,6 +734,9 @@ class Emu3VQVAE(PreTrainedModel):
config_class = Emu3VQVAEConfig
base_model_prefix = "emuvideovq"
main_input_name = "pixel_values"
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_flex_attn = True
_no_split_modules = [
"Emu3VQVAETemporalResnetBlock",
"Emu3VQVAEAttentionBlock",

View File

@@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch Idefics2 model."""
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -27,9 +26,8 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -41,10 +39,6 @@ from ..auto import AutoModel
from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Idefics2Config"
@@ -185,6 +179,33 @@ class Idefics2VisionEmbeddings(nn.Module):
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics2Vision
class Idefics2VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -220,140 +241,38 @@ class Idefics2VisionAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
batch_size, seq_length, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
"""
Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (Idefics2VisionRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
@@ -362,12 +281,6 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
return attn_output, attn_weights
IDEFICS_VISION_ATTENTION_CLASSES = {
"eager": Idefics2VisionAttention,
"flash_attention_2": Idefics2VisionFlashAttention2,
}
# Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision
class Idefics2VisionMLP(nn.Module):
def __init__(self, config):
@@ -437,7 +350,7 @@ class Idefics2EncoderLayer(nn.Module):
def __init__(self, config: Idefics2VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
self.self_attn = Idefics2VisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -600,6 +513,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
def _init_weights(self, module):
@@ -646,8 +560,10 @@ IDEFICS2_INPUTS_DOCSTRING = r"""
IDEFICS2_START_DOCSTRING,
)
class Idefics2VisionTransformer(Idefics2PreTrainedModel):
_supports_sdpa = False
config_class = Idefics2VisionConfig
_supports_sdpa = True
_supports_flash_attention_2 = True
_supports_flex_attn = True
def __init__(self, config: Idefics2VisionConfig):
super().__init__(config)
@@ -761,7 +677,7 @@ class Idefics2PerceiverAttention(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None) -> None:
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
super().__init__()
self.config = config
self.layer_idx = None
self.hidden_size = config.hidden_size
self.num_heads = config.resampler_n_heads
@@ -769,6 +685,7 @@ class Idefics2PerceiverAttention(nn.Module):
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -804,179 +721,41 @@ class Idefics2PerceiverAttention(nn.Module):
hidden_states = torch.concat([context, latents], dim=-2)
query_states = self.q_proj(latents)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(latents)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
queries = queries.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
values = values.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
keys, values = past_key_value.update(keys, values, self.layer_idx)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
# TODO cyril: modular
class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
"""
Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
# Ignore copy
def forward(
self,
latents: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = latents.size()
kv_seq_len = q_len + context.size()[1]
# Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
# Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
query_states = self.q_proj(latents)
key_states = self.k_proj(torch.cat([context, latents], dim=-2))
value_states = self.v_proj(torch.cat([context, latents], dim=-2))
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
f" head_dim`), got {past_key.shape}"
)
past_key_value = (past_key, past_value)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reashape to the expected shape for Flash Attention
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
q_len,
dropout=dropout_rate,
sliding_window=None,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
scaling=self.scaling,
dropout=0.0 if not self.training else self.attention_dropout,
)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
@@ -985,12 +764,6 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
return attn_output, attn_weights, past_key_value
IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
"eager": Idefics2PerceiverAttention,
"flash_attention_2": Idefics2PerceiverFlashAttention2,
}
class Idefics2PerceiverLayer(nn.Module):
def __init__(self, config, layer_idx: int):
super().__init__()
@@ -1001,7 +774,7 @@ class Idefics2PerceiverLayer(nn.Module):
self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.self_attn = Idefics2PerceiverAttention(config, layer_idx=layer_idx)
self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self.mlp = Idefics2MLP(
hidden_size=config.hidden_size,
@@ -1084,8 +857,10 @@ IDEFICS2_INPUTS_DOCSTRING = r"""
IDEFICS2_START_DOCSTRING,
)
class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
_supports_sdpa = False
config_class = Idefics2PerceiverConfig
_supports_sdpa = True
_supports_flash_attention_2 = True
_supports_flex_attn = True
def __init__(self, config) -> None:
super().__init__(config)

View File

@@ -15,7 +15,7 @@
"""PyTorch Idefics3 model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -23,12 +23,11 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...cache_utils import DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -39,10 +38,6 @@ from ..auto import AutoModel
from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Idefics3Config"
@@ -184,6 +179,30 @@ class Idefics3VisionEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
class Idefics3VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -219,141 +238,38 @@ class Idefics3VisionAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
batch_size, seq_length, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention2 with Idefics2->Idefics3
class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
"""
Idefics3Vision flash attention module. This module inherits from `Idefics3VisionAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (Idefics3VisionRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
@@ -362,12 +278,6 @@ class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
return attn_output, attn_weights
IDEFICS_VISION_ATTENTION_CLASSES = {
"eager": Idefics3VisionAttention,
"flash_attention_2": Idefics3VisionFlashAttention2,
}
# Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
class Idefics3VisionMLP(nn.Module):
def __init__(self, config):
@@ -400,7 +310,7 @@ class Idefics3EncoderLayer(nn.Module):
def __init__(self, config: Idefics3VisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
self.self_attn = Idefics3VisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Idefics3VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -620,6 +530,7 @@ class Idefics3PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights
@@ -666,7 +577,9 @@ IDEFICS3_VISION_START_DOCSTRING = r"""
)
class Idefics3VisionTransformer(Idefics3PreTrainedModel):
config_class = Idefics3VisionConfig
_supports_sdpa = False
_supports_sdpa = True
_supports_flash_attention_2 = True
_supports_flex_attn = True
def __init__(self, config: Idefics3VisionConfig):
super().__init__(config)

View File

@@ -150,12 +150,11 @@ class Phi4MultimodalVisionEncoderLayer(nn.Module):
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Phi4MultimodalVisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Phi4MultimodalVisionMLP(config)
self.self_attn = Phi4MultimodalVisionAttention(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Phi4MultimodalVisionMLP(config)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,

View File

@@ -17,7 +17,7 @@
import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -28,9 +28,8 @@ from torch.nn.init import _calculate_fan_in_and_fan_out
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -43,10 +42,6 @@ from ...utils import (
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
# General docstring
@@ -360,11 +355,33 @@ class SiglipTextEmbeddings(nn.Module):
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -377,6 +394,7 @@ class SiglipAttention(nn.Module):
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
@@ -391,130 +409,38 @@ class SiglipAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
batch_size, seq_length, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class SiglipFlashAttention2(SiglipAttention):
"""
SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
is_causal = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
@@ -523,79 +449,6 @@ class SiglipFlashAttention2(SiglipAttention):
return attn_output, attn_weights
class SiglipSdpaAttention(SiglipAttention):
"""
Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
is_causal = False
# Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if self.is_causal and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
SIGLIP_ATTENTION_CLASSES = {
"eager": SiglipAttention,
"flash_attention_2": SiglipFlashAttention2,
"sdpa": SiglipSdpaAttention,
}
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
def __init__(self, config):
@@ -613,15 +466,14 @@ class SiglipMLP(nn.Module):
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipConfig):
def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.self_attn = SiglipAttention(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,

View File

@@ -21,7 +21,7 @@
import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,9 +32,8 @@ from torch.nn.init import _calculate_fan_in_and_fan_out
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -46,10 +45,6 @@ from ...utils import (
from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
# General docstring
@@ -252,10 +247,33 @@ class Siglip2VisionEmbeddings(nn.Module):
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Siglip2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -268,6 +286,7 @@ class Siglip2Attention(nn.Module):
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
@@ -282,130 +301,38 @@ class Siglip2Attention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
batch_size, seq_length, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class Siglip2FlashAttention2(Siglip2Attention):
"""
Siglip2Attention flash attention module. This module inherits from `Siglip2Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
is_causal = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
@@ -414,72 +341,6 @@ class Siglip2FlashAttention2(Siglip2Attention):
return attn_output, attn_weights
class Siglip2SdpaAttention(Siglip2Attention):
"""
Siglip2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Siglip2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
is_causal = False
# Adapted from Siglip2Attention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Siglip2Model is using Siglip2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if self.is_causal and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
class Siglip2MLP(nn.Module):
def __init__(self, config):
super().__init__()
@@ -495,23 +356,15 @@ class Siglip2MLP(nn.Module):
return hidden_states
SIGLIP2_ATTENTION_CLASSES = {
"eager": Siglip2Attention,
"flash_attention_2": Siglip2FlashAttention2,
"sdpa": Siglip2SdpaAttention,
}
class Siglip2EncoderLayer(nn.Module):
def __init__(self, config: Siglip2Config):
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SIGLIP2_ATTENTION_CLASSES[config._attn_implementation](config=config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
self.self_attn = Siglip2Attention(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,

View File

@@ -20,19 +20,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...cache_utils import DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -43,10 +42,6 @@ from ..auto import AutoModel
from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SmolVLMConfig"
@@ -81,6 +76,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
def _init_weights(self, module):
@@ -161,6 +157,29 @@ class SmolVLMVisionEmbeddings(nn.Module):
return embeddings
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class SmolVLMVisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -194,140 +213,38 @@ class SmolVLMVisionAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
batch_size, seq_length, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class SmolVLMVisionFlashAttention2(SmolVLMVisionAttention):
"""
SmolVLMVision flash attention module. This module inherits from `SmolVLMVisionAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (SmolVLMVisionRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
@@ -351,17 +268,11 @@ class SmolVLMVisionMLP(nn.Module):
return hidden_states
IDEFICS_VISION_ATTENTION_CLASSES = {
"eager": SmolVLMVisionAttention,
"flash_attention_2": SmolVLMVisionFlashAttention2,
}
class SmolVLMEncoderLayer(nn.Module):
def __init__(self, config: SmolVLMVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
self.self_attn = SmolVLMVisionAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SmolVLMVisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -516,7 +427,9 @@ SMOLVLM_VISION_START_DOCSTRING = r"""
)
class SmolVLMVisionTransformer(SmolVLMPreTrainedModel):
config_class = SmolVLMVisionConfig
_supports_sdpa = False
_supports_sdpa = True
_supports_flash_attention_2 = True
_supports_flex_attn = True
def __init__(self, config: SmolVLMVisionConfig):
super().__init__(config)

View File

@@ -344,17 +344,15 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
vision_attn = None if model.vision_model._supports_sdpa else "eager"
perceiver_attn = None if model.connector.perceiver_resampler._supports_sdpa else "eager"
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == perceiver_attn)
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "eager")
self.assertTrue(model_eager.connector.perceiver_resampler.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__

View File

@@ -18,7 +18,6 @@ import inspect
import os
import tempfile
import unittest
from typing import Tuple
import numpy as np
import requests
@@ -27,30 +26,28 @@ from pytest import mark
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_torch_sdpa_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
is_flaky,
random_attention_mask,
require_torch_sdpa,
)
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -61,9 +58,6 @@ if is_torch_available():
from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel
if is_torch_sdpa_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
if is_vision_available():
from PIL import Image
@@ -71,6 +65,7 @@ if is_vision_available():
class SiglipModelTesterMixin(ModelTesterMixin):
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -81,171 +76,24 @@ class SiglipModelTesterMixin(ModelTesterMixin):
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
# SigLip has one shared cls attr for all models, so we assign both submodels heer
vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager"
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn)
if hasattr(model_sdpa, "vision_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
if hasattr(model_sdpa, "text_model"):
self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
def test_eager_matches_sdpa_inference(
self,
torch_dtype: str,
use_attention_mask_options: Tuple[bool, ...] = (True, False),
logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Convert to torch dtype
dtypes = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
torch_dtype = dtypes[torch_dtype]
atols = {
torch.float32: 1e-5,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
rtols = {
torch.float32: 1e-4,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
atol = atols[torch_dtype]
rtol = rtols[torch_dtype]
def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
# but it would be nicer to have an efficient way to use parameterized.expand
cases = [
(use_mask, output_attentions, sdpa_backend, batch_size)
for use_mask in use_attention_mask_options
for output_attentions in [True, False]
for sdpa_backend in [
SDPBackend.MATH,
[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
]
for batch_size in [1, 5]
]
fail_cases = []
for use_mask, output_attentions, sdpa_backend, batch_size in cases:
processed_inputs = inputs_dict.copy()
# convert to torch_dtype
if "pixel_values" in processed_inputs:
processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)
# slice for different batch sizes
for key in ["pixel_values", "input_ids", "attention_mask"]:
if key in processed_inputs:
processed_inputs[key] = processed_inputs[key][:batch_size]
# set attention mask with left padding
if not use_mask:
processed_inputs.pop("attention_mask", None)
else:
dummy_attention_mask = processed_inputs["attention_mask"]
dummy_attention_mask[:] = 1
dummy_attention_mask[:, :1] = 0
processed_inputs["attention_mask"] = dummy_attention_mask
processed_inputs["output_attentions"] = output_attentions
processed_inputs["output_hidden_states"] = True
current_case = (
f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
)
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
with torch.no_grad():
try:
with sdpa_kernel(sdpa_backend):
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
except Exception as e:
fail_cases.append(f"{current_case}: {e}")
continue
for key in logit_keys:
eager_logits = outputs_eager[key]
sdpa_logits = outputs_sdpa[key]
if use_mask:
eager_logits = eager_logits[:, 1:]
sdpa_logits = sdpa_logits[:, 1:]
is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
if not is_close:
fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
class SiglipVisionModelTester:
def __init__(
@@ -409,20 +257,12 @@ class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase):
model = SiglipVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("pooler_output", "last_hidden_state"),
use_attention_mask_options=(False,),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
def test_eager_matches_sdpa_inference(self, *args):
# adding only flaky decorator here and call the parent test method
return getattr(ModelTesterMixin, self._testMethodName)(self)
class SiglipTextModelTester:
@@ -565,21 +405,6 @@ class SiglipTextModelTest(SiglipModelTesterMixin, unittest.TestCase):
model = SiglipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("pooler_output", "last_hidden_state"),
use_attention_mask_options=(False, True),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class SiglipModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -634,6 +459,7 @@ class SiglipModelTester:
@require_torch
class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
additional_model_inputs = ["pixel_values"]
all_model_classes = (SiglipModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": SiglipModel} if is_torch_available() else {}
fx_compatible = False
@@ -862,21 +688,6 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("SigLIP does not support right padding")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
use_attention_mask_options=(False, True),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class SiglipForImageClassificationModelTester(SiglipModelTester):
def __init__(self, parent):
@@ -943,19 +754,6 @@ class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTest
def test_initialization(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,)
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
# We will verify our results on an image of cute cats
def prepare_img():

View File

@@ -17,7 +17,6 @@
import inspect
import tempfile
import unittest
from typing import Tuple
import numpy as np
from parameterized import parameterized
@@ -25,29 +24,27 @@ from pytest import mark
from transformers import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_torch_sdpa_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
floats_tensor,
ids_tensor,
is_flaky,
random_attention_mask,
require_torch_sdpa,
)
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -58,9 +55,6 @@ if is_torch_available():
from transformers import Siglip2ForImageClassification, Siglip2Model, Siglip2TextModel, Siglip2VisionModel
if is_torch_sdpa_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
if is_vision_available():
from PIL import Image, ImageDraw
@@ -68,6 +62,7 @@ if is_vision_available():
class Siglip2ModelTesterMixin(ModelTesterMixin):
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -78,171 +73,24 @@ class Siglip2ModelTesterMixin(ModelTesterMixin):
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
# SigLip has one shared cls attr for all models, so we assign both submodels heer
vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager"
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn)
if hasattr(model_sdpa, "vision_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
if hasattr(model_sdpa, "text_model"):
self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
def test_eager_matches_sdpa_inference(
self,
torch_dtype: str,
use_attention_mask_options: Tuple[bool, ...] = (True, False),
logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Convert to torch dtype
dtypes = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
torch_dtype = dtypes[torch_dtype]
atols = {
torch.float32: 1e-5,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
rtols = {
torch.float32: 1e-4,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
atol = atols[torch_dtype]
rtol = rtols[torch_dtype]
def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
# but it would be nicer to have an efficient way to use parameterized.expand
cases = [
(use_mask, output_attentions, sdpa_backend, batch_size)
for use_mask in use_attention_mask_options
for output_attentions in [True, False]
for sdpa_backend in [
SDPBackend.MATH,
[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
]
for batch_size in [1, 5]
]
fail_cases = []
for use_mask, output_attentions, sdpa_backend, batch_size in cases:
processed_inputs = inputs_dict.copy()
# convert to torch_dtype
if "pixel_values" in processed_inputs:
processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)
# slice for different batch sizes
for key in processed_inputs.keys():
if isinstance(processed_inputs[key], (torch.Tensor, list, tuple)):
processed_inputs[key] = processed_inputs[key][:batch_size]
# set attention mask with left padding
if not use_mask:
processed_inputs.pop("attention_mask", None)
else:
dummy_attention_mask = processed_inputs["attention_mask"]
dummy_attention_mask[:] = 1
dummy_attention_mask[:, :1] = 0
processed_inputs["attention_mask"] = dummy_attention_mask
processed_inputs["output_attentions"] = output_attentions
processed_inputs["output_hidden_states"] = True
current_case = (
f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
)
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
with torch.no_grad():
try:
with sdpa_kernel(sdpa_backend):
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
except Exception as e:
fail_cases.append(f"{current_case}: {e}")
continue
for key in logit_keys:
eager_logits = outputs_eager[key]
sdpa_logits = outputs_sdpa[key]
if use_mask:
eager_logits = eager_logits[:, 1:]
sdpa_logits = sdpa_logits[:, 1:]
is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
if not is_close:
fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -422,6 +270,7 @@ class Siglip2VisionModelTest(Siglip2ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (Siglip2VisionModel,) if is_torch_available() else ()
additional_model_inputs = ["pixel_attention_mask", "spatial_shapes"]
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
@@ -497,20 +346,12 @@ class Siglip2VisionModelTest(Siglip2ModelTesterMixin, unittest.TestCase):
model = Siglip2VisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("pooler_output", "last_hidden_state"),
use_attention_mask_options=(False,),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
def test_eager_matches_sdpa_inference(self, *args):
# adding only flaky decorator here and call the parent test method
return getattr(ModelTesterMixin, self._testMethodName)(self)
class Siglip2TextModelTester:
@@ -648,21 +489,6 @@ class Siglip2TextModelTest(Siglip2ModelTesterMixin, unittest.TestCase):
model = Siglip2TextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("pooler_output", "last_hidden_state"),
use_attention_mask_options=(False, True),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class Siglip2ModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -725,6 +551,11 @@ class Siglip2ModelTester:
class Siglip2ModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Siglip2Model,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": Siglip2Model} if is_torch_available() else {}
additional_model_inputs = [
"pixel_values",
"pixel_attention_mask",
"spatial_shapes",
]
fx_compatible = False
test_head_masking = False
test_pruning = False
@@ -796,21 +627,6 @@ class Siglip2ModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Siglip2 does not support right padding")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
use_attention_mask_options=(False, True),
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
class Siglip2ForImageClassificationModelTester(Siglip2ModelTester):
def __init__(self, parent):
@@ -841,6 +657,7 @@ class Siglip2ForImageClassificationModelTester(Siglip2ModelTester):
class Siglip2ForImageClassificationModelTest(Siglip2ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Siglip2ForImageClassification,) if is_torch_available() else ()
pipeline_model_mapping = {"image-classification": Siglip2ForImageClassification} if is_torch_available() else {}
additional_model_inputs = ["pixel_values", "pixel_attention_mask", "spatial_shapes"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@@ -881,19 +698,6 @@ class Siglip2ForImageClassificationModelTest(Siglip2ModelTesterMixin, PipelineTe
def test_initialization(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,)
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
super().test_sdpa_can_dispatch_composite_models()
# Draw a circle on an images with different aspect ratios
def prepare_images():

View File

@@ -3457,7 +3457,7 @@ class ModelTesterMixin:
):
# TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like
# models have a custom mixin, which we detect to skip this test.
if not any(".ModelTesterMixin" in str(base) for base in self.__class__.__bases__):
if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__):
self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`")
if not self.has_attentions:
@@ -3549,206 +3549,213 @@ class ModelTesterMixin:
model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
self.skipTest(reason="Model does not support output_attentions")
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
self.skipTest(reason="Model does not support output_attentions")
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"):
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
processed_inputs = {}
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
for key in getattr(self, "additional_model_inputs", []):
processed_inputs[key] = inputs_dict[key]
for key, value in processed_inputs.items():
if torch.is_floating_point(value):
value = value.to(torch_dtype)
# extend value to have at least `input_data_batch_size` elements
if value.shape[0] < input_data_batch_size:
size = (input_data_batch_size - value.shape[0], *value.shape[1:])
if torch.is_floating_point(value):
extension = torch.rand(size=size, dtype=value.dtype, device=torch_device)
else:
extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device)
value = torch.cat((value, extension), dim=0).to(torch_device)
processed_inputs[key] = value[:input_data_batch_size]
if not use_attention_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get(
"decoder_input_ids", processed_inputs[model.main_input_name]
).shape[-1]
else:
seqlen = processed_inputs[model.main_input_name].shape[-1]
dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
# extend dummy_attention_mask to have at least `batch_size` elements
if dummy_attention_mask.shape[0] < batch_size:
size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:])
extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
dummy_input = inputs_dict[model.main_input_name]
decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name])
decoder_input_ids = decoder_input_ids[:input_data_batch_size]
if decoder_input_ids.shape[0] != input_data_batch_size:
extension = torch.ones(
input_data_batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:input_data_batch_size]
if dummy_input.shape[0] != input_data_batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
input_data_batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(input_data_batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_attention_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:input_data_batch_size]
if decoder_input_ids.shape[0] != input_data_batch_size:
extension = torch.ones(
input_data_batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
processed_inputs = {
model.main_input_name: dummy_input,
# TODO: never an `attention_mask` arg here?
processed_inputs.update(
{
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
processed_inputs = {
model.main_input_name: dummy_input,
)
else:
processed_inputs.update(
{
"output_hidden_states": True,
}
)
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters:
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
# TODO: rename logits -> hidden_states
if hasattr(outputs_eager, "vision_hidden_states"):
logits_eager = outputs_eager.vision_hidden_states[-1]
logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
elif hasattr(outputs_eager, "audio_values"):
logits_eager = outputs_eager.audio_values
logits_sdpa = outputs_sdpa.audio_values
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
logits_eager = (
outputs_eager.decoder_hidden_states[-1]
if hasattr(outputs_eager, "decoder_hidden_states")
else outputs_eager.hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.decoder_hidden_states[-1]
if hasattr(outputs_sdpa, "decoder_hidden_states")
else outputs_sdpa.hidden_states[-1]
)
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# Masked tokens output slightly deviates - we don't mind that.
if use_attention_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
prepared_inputs = {
k: v.to(torch_device) if isinstance(v, torch.Tensor) else v
for k, v in prepared_inputs.items()
}
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if "logits_per_text" in outputs_eager:
key = "logits_per_text"
elif "vision_hidden_states" in outputs_eager:
key = "vision_hidden_states"
elif "audio_values" in outputs_eager:
key = "audio_values"
elif "decoder_hidden_states" in outputs_eager:
key = "decoder_hidden_states"
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
key = "logits"
else:
key = "hidden_states"
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
# TODO: rename logits -> hidden_states
logits_eager = outputs_eager[key]
logits_sdpa = outputs_sdpa[key]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]:
logits_eager = logits_eager[-1]
logits_sdpa = logits_sdpa[-1]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
if key == "logits_per_text":
nan_mask = torch.isnan(logits_eager)
logits_eager[nan_mask] = 0
logits_sdpa[nan_mask] = 0
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
raise ValueError(
f"mean relative difference: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = "
f"{rtol}"
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_attention_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
raise ValueError(
f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = "
f"{rtol}"
)
@require_torch_sdpa
@require_torch_gpu