From a177821b24565b8767e8b0e52a78569caade3040 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 8 Jul 2024 11:10:02 +0100 Subject: [PATCH] Add FA2 and `sdpa` support for SigLIP (#31499) * Rebase to main * Fix attention implementation autoset for tex and vision configs * Fixup * Minor fixes * Fix copies * Fix attention_mask for FA2 * Add eqvivalence tests for siglip * Remove right padding test * Uncomment flaky * Fix import * Add to docs * Fix test message * Add sdpa * Add sdpa equivalence test * Add siglip sdpa to docs * Fix typing for attention output * Add sdpa tests * Fix signature of FA2 * Autoset attn_implementation in config * Rename bsz -> batch_size * Move back autoset attn method * Mark as flaky * Correct attention mask padding * [run-slow] siglip * Add FA2 and sdpa docs * Style fix * Remove flaky for FA2 test * Change attention implementation set * Change attn_implementaiton propogation * Fix typos * Add modality to assert message * Add more sdpa backends in test * [run slow] siglip * Add math sdpa backend for all options * [run slow] siglip --- docs/source/en/model_doc/siglip.md | 82 +++++ docs/source/en/perf_infer_gpu_one.md | 2 + .../models/idefics2/modeling_idefics2.py | 2 +- .../models/siglip/modeling_siglip.py | 308 +++++++++++++++++- tests/models/siglip/test_modeling_siglip.py | 299 ++++++++++++++++- 5 files changed, 680 insertions(+), 13 deletions(-) diff --git a/docs/source/en/model_doc/siglip.md b/docs/source/en/model_doc/siglip.md index 2795cfb711..4f46174fb1 100644 --- a/docs/source/en/model_doc/siglip.md +++ b/docs/source/en/model_doc/siglip.md @@ -107,6 +107,88 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. + +## Combining SigLIP and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``) + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> import requests +>>> from PIL import Image +>>> from transformers import SiglipProcessor, SiglipModel +>>> device = "cuda" # the device to load the model onto + +>>> model = SiglipModel.from_pretrained( +... "google/siglip-so400m-patch14-384", +... attn_implementation="flash_attention_2", +... torch_dtype=torch.float16, +... device_map=device, +... ) +>>> processor = SiglipProcessor.from_pretrained("google/siglip-so400m-patch14-384") + +>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw) + +>>> candidate_labels = ["2 cats", "2 dogs"] +# follows the pipeline prompt template to get same results +>>> candidate_labels = [f'This is a photo of {label}.' for label in candidate_labels] +# important: we pass `padding=max_length` since the model was trained with this +>>> inputs = processor(text=candidate_labels, images=image, padding="max_length", return_tensors="pt") +>>> inputs.to(device) + +>>> with torch.no_grad(): +... with torch.autocast(device): +... outputs = model(**inputs) + +>>> logits_per_image = outputs.logits_per_image +>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities +>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'") +51.3% that image 0 is 'This is a photo of 2 cats.' +``` + + +## Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +You may set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. Make sure you have `torch>=2.1.1`. + +```python +>>> from transformers import SiglipModel + +>>> model = SiglipModel.from_pretrained( +... "google/siglip-so400m-patch14-384", +... attn_implementation="sdpa", +... torch_dtype=torch.float16, +... device_map=device, +... ) +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + + +## Expected speedups + +Below is an expected speedup diagram that compares inference time between the native implementation in transformers using `google/siglip-so400m-patch14-384` checkpoint in `float16` precision and the Flash Attention 2 / SDPA version of the model using different batch sizes. + +
+ +
+ + ## SiglipConfig [[autodoc]] SiglipConfig diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 1569bef1f6..b18e737ff9 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -70,6 +70,7 @@ FlashAttention-2 is currently supported for the following architectures: * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model) +* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) @@ -231,6 +232,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) +* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index aacd0fdad1..2857e366d2 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -221,7 +221,7 @@ class Idefics2VisionAttention(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 068f2173d7..50d41ef509 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -21,6 +21,7 @@ from typing import Any, Optional, Tuple, Union import numpy as np import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -34,12 +35,19 @@ from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) # General docstring @@ -47,6 +55,19 @@ _CONFIG_FOR_DOC = "SiglipConfig" _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -373,7 +394,7 @@ class SiglipAttention(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() @@ -421,6 +442,266 @@ class SiglipAttention(nn.Module): return attn_output, attn_weights +class SiglipFlashAttention2(SiglipAttention): + """ + SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + is_causal = False + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class SiglipSdpaAttention(SiglipAttention): + """ + Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + is_causal = False + + # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if self.is_causal and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +SIGLIP_ATTENTION_CLASSES = { + "eager": SiglipAttention, + "flash_attention_2": SiglipFlashAttention2, + "sdpa": SiglipSdpaAttention, +} + + # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): def __init__(self, config): @@ -437,12 +718,11 @@ class SiglipMLP(nn.Module): return hidden_states -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip class SiglipEncoderLayer(nn.Module): def __init__(self, config: SiglipConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = SiglipAttention(config) + self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -503,6 +783,8 @@ class SiglipPreTrainedModel(PreTrainedModel): "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead", ] + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" @@ -754,6 +1036,7 @@ class SiglipTextTransformer(nn.Module): self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.head = nn.Linear(embed_dim, embed_dim) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) @@ -786,7 +1069,7 @@ class SiglipTextTransformer(nn.Module): # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. # expand attention_mask - if attention_mask is not None: + if attention_mask is not None and not self._use_flash_attention_2: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) @@ -1041,8 +1324,13 @@ class SiglipModel(SiglipPreTrainedModel): text_config = config.text_config vision_config = config.vision_config - self.text_model = SiglipTextTransformer(text_config) - self.vision_model = SiglipVisionTransformer(vision_config) + # First, initialize the text and vision models with proper attention implementation + text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation) + vision_model = SiglipVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model self.logit_scale = nn.Parameter(torch.randn(1)) self.logit_bias = nn.Parameter(torch.randn(1)) @@ -1270,7 +1558,13 @@ class SiglipForImageClassification(SiglipPreTrainedModel): super().__init__(config) self.num_labels = config.num_labels - self.vision_model = SiglipVisionTransformer(config.vision_config) + + # Create the vision model with proper attention + # and take only vision_model submodule (for backward compatibility) + vision_model = SiglipVisionModel._from_config( + config.vision_config, attn_implementation=config._attn_implementation + ) + self.vision_model = vision_model.vision_model # Classifier head self.classifier = ( diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 8bdc995e51..9d1e3109b3 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -18,18 +18,30 @@ import inspect import os import tempfile import unittest +from typing import Tuple import numpy as np import requests +from parameterized import parameterized +from pytest import mark from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig from transformers.testing_utils import ( + require_flash_attn, require_torch, + require_torch_gpu, + require_torch_sdpa, require_vision, slow, torch_device, ) -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import ( + is_torch_available, + is_torch_bf16_available_on_device, + is_torch_fp16_available_on_device, + is_torch_sdpa_available, + is_vision_available, +) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( @@ -37,6 +49,7 @@ from ...test_modeling_common import ( _config_zero_init, floats_tensor, ids_tensor, + is_flaky, random_attention_mask, ) from ...test_pipeline_mixin import PipelineTesterMixin @@ -48,6 +61,8 @@ if is_torch_available(): from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel +if is_torch_sdpa_available(): + from torch.nn.attention import SDPBackend, sdpa_kernel if is_vision_available(): from PIL import Image @@ -55,6 +70,155 @@ if is_vision_available(): from transformers import SiglipProcessor +class SiglipModelTesterMixin(ModelTesterMixin): + def test_eager_matches_sdpa_inference( + self, + torch_dtype: str, + use_attention_mask_options: Tuple[bool, ...] = (True, False), + logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), + ): + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Convert to torch dtype + dtypes = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + torch_dtype = dtypes[torch_dtype] + + atols = { + torch.float32: 1e-5, + torch.bfloat16: 3e-2, + torch.float16: 5e-3, + } + rtols = { + torch.float32: 1e-4, + torch.bfloat16: 3e-2, + torch.float16: 5e-3, + } + + atol = atols[torch_dtype] + rtol = rtols[torch_dtype] + + def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): + return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # Load the model with SDPA + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + # Load model with eager attention + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, + # but it would be nicer to have an efficient way to use parameterized.expand + cases = [ + (use_mask, output_attentions, sdpa_backend, batch_size) + for use_mask in use_attention_mask_options + for output_attentions in [True, False] + for sdpa_backend in [ + SDPBackend.MATH, + [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH], + [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], + [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH], + ] + for batch_size in [1, 5] + ] + fail_cases = [] + + for use_mask, output_attentions, sdpa_backend, batch_size in cases: + processed_inputs = inputs_dict.copy() + + # convert to torch_dtype + if "pixel_values" in processed_inputs: + processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype) + + # slice for different batch sizes + for key in ["pixel_values", "input_ids", "attention_mask"]: + if key in processed_inputs: + processed_inputs[key] = processed_inputs[key][:batch_size] + + # set attention mask with left padding + if not use_mask: + processed_inputs.pop("attention_mask", None) + else: + dummy_attention_mask = processed_inputs["attention_mask"] + dummy_attention_mask[:] = 1 + dummy_attention_mask[:, :1] = 0 + processed_inputs["attention_mask"] = dummy_attention_mask + + processed_inputs["output_attentions"] = output_attentions + processed_inputs["output_hidden_states"] = True + + current_case = ( + f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}" + ) + + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + + with torch.no_grad(): + try: + with sdpa_kernel(sdpa_backend): + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + except Exception as e: + fail_cases.append(f"{current_case}: {e}") + continue + + for key in logit_keys: + eager_logits = outputs_eager[key] + sdpa_logits = outputs_sdpa[key] + + if use_mask: + eager_logits = eager_logits[:, 1:] + sdpa_logits = sdpa_logits[:, 1:] + + is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol) + if not is_close: + fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol)) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + class SiglipVisionModelTester: def __init__( self, @@ -135,7 +299,7 @@ class SiglipVisionModelTester: @require_torch -class SiglipVisionModelTest(ModelTesterMixin, unittest.TestCase): +class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase): """ Here we also overwrite some of the tests of test_modeling_common.py, as SIGLIP does not use input_ids, inputs_embeds, attention_mask and seq_length. @@ -225,6 +389,17 @@ class SiglipVisionModelTest(ModelTesterMixin, unittest.TestCase): model = SiglipVisionModel.from_pretrained(model_name) self.assertIsNotNone(model) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + super().test_eager_matches_sdpa_inference( + torch_dtype=torch_dtype, + logit_keys=("pooler_output", "last_hidden_state"), + use_attention_mask_options=(False,), + ) + class SiglipTextModelTester: def __init__( @@ -314,7 +489,7 @@ class SiglipTextModelTester: @require_torch -class SiglipTextModelTest(ModelTesterMixin, unittest.TestCase): +class SiglipTextModelTest(SiglipModelTesterMixin, unittest.TestCase): all_model_classes = (SiglipTextModel,) if is_torch_available() else () fx_compatible = False test_pruning = False @@ -376,6 +551,17 @@ class SiglipTextModelTest(ModelTesterMixin, unittest.TestCase): model = SiglipTextModel.from_pretrained(model_name) self.assertIsNotNone(model) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + super().test_eager_matches_sdpa_inference( + torch_dtype=torch_dtype, + logit_keys=("pooler_output", "last_hidden_state"), + use_attention_mask_options=(False, True), + ) + class SiglipModelTester: def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): @@ -429,7 +615,7 @@ class SiglipModelTester: @require_torch -class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): +class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (SiglipModel,) if is_torch_available() else () pipeline_model_mapping = {"feature-extraction": SiglipModel} if is_torch_available() else {} fx_compatible = False @@ -571,6 +757,100 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model = SiglipModel.from_pretrained(model_name) self.assertIsNotNone(model) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16) + dummy_input_ids = inputs_dict["input_ids"] + + outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True) + outputs_fa = model_fa( + pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True + ) + + self.assertTrue( + torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2), + f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}", + ) + self.assertTrue( + torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2), + f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}", + ) + + # Test with attention mask + dummy_attention_mask = inputs_dict["attention_mask"] + + if dummy_attention_mask is not None: + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + outputs = model( + pixel_values=dummy_pixel_values, + input_ids=dummy_input_ids, + attention_mask=dummy_attention_mask, + output_hidden_states=True, + ) + outputs_fa = model_fa( + pixel_values=dummy_pixel_values, + input_ids=dummy_input_ids, + attention_mask=dummy_attention_mask, + output_hidden_states=True, + ) + + self.assertTrue( + torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2), + f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}", + ) + self.assertTrue( + torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2), + f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}", + ) + + # check with inference + dropout + model.train() + _ = model_fa( + pixel_values=dummy_pixel_values, + input_ids=dummy_input_ids, + attention_mask=dummy_attention_mask, + output_hidden_states=True, + ) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest("SigLIP does not support right padding") + + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + super().test_eager_matches_sdpa_inference( + torch_dtype=torch_dtype, + logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"), + use_attention_mask_options=(False, True), + ) + class SiglipForImageClassificationModelTester(SiglipModelTester): def __init__(self, parent): @@ -594,7 +874,7 @@ class SiglipForImageClassificationModelTester(SiglipModelTester): @require_torch -class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): +class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (SiglipForImageClassification,) if is_torch_available() else () pipeline_model_mapping = {"image-classification": SiglipForImageClassification} if is_torch_available() else {} fx_compatible = False @@ -636,6 +916,15 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixi def test_initialization(self): pass + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + @slow + @is_flaky() + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + super().test_eager_matches_sdpa_inference( + torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,) + ) + # We will verify our results on an image of cute cats def prepare_img():