From 66291778dd7cea6622219257bf890b20835a6de9 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Thu, 20 Mar 2025 15:15:01 +0000 Subject: [PATCH] Refactor Attention implementation for ViT-based models (#36545) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor vit attention * Refactor ViT-based models * 🚨🚨🚨 Fix prefix for DPT * Update params order * trigger tests * Fix Dinov2 attention * Fix DPT attention impl propagation for backbone config * Common test fix: config is modif. inplace - avoid it * view->reshape * Fixup * Fixup * Enable IJepa FA2 * Add FA2 in corresponding model docs --- .../audio-spectrogram-transformer.md | 1 + docs/source/en/model_doc/deit.md | 1 + docs/source/en/model_doc/dinov2.md | 1 + .../en/model_doc/dinov2_with_registers.md | 1 + docs/source/en/model_doc/dpt.md | 2 + docs/source/en/model_doc/ijepa.md | 1 + docs/source/en/model_doc/videomae.md | 1 + docs/source/en/model_doc/vit.md | 1 + docs/source/en/model_doc/vit_mae.md | 1 + docs/source/en/model_doc/vit_msn.md | 1 + docs/source/en/model_doc/vivit.md | 1 + docs/source/en/model_doc/yolos.md | 1 + src/transformers/modeling_utils.py | 4 +- .../modeling_audio_spectrogram_transformer.py | 144 +++++++----------- src/transformers/models/deit/modeling_deit.py | 144 +++++++----------- .../depth_anything/modeling_depth_anything.py | 3 +- .../models/dinov2/modeling_dinov2.py | 137 +++++++---------- .../modeling_dinov2_with_registers.py | 135 +++++++--------- .../models/dpt/configuration_dpt.py | 4 + src/transformers/models/dpt/modeling_dpt.py | 90 +++++++---- .../models/ijepa/modeling_ijepa.py | 142 +++++++---------- .../models/ijepa/modular_ijepa.py | 1 + .../models/videomae/modeling_videomae.py | 124 +++++++-------- src/transformers/models/vit/modeling_vit.py | 140 +++++++---------- .../models/vit_mae/modeling_vit_mae.py | 144 +++++++----------- .../models/vit_msn/modeling_vit_msn.py | 141 +++++++---------- .../modeling_vitpose_backbone.py | 86 
+++++++---- .../models/vivit/modeling_vivit.py | 141 +++++++---------- .../models/yolos/modeling_yolos.py | 141 +++++++---------- .../models/zoedepth/modeling_zoedepth.py | 3 +- tests/models/dpt/test_modeling_dpt.py | 4 + .../models/videomae/test_modeling_videomae.py | 65 +++++++- tests/models/vit_mae/test_modeling_vit_mae.py | 68 ++++++++- tests/test_configuration_common.py | 2 +- tests/test_modeling_common.py | 31 ++-- 35 files changed, 932 insertions(+), 975 deletions(-) diff --git a/docs/source/en/model_doc/audio-spectrogram-transformer.md b/docs/source/en/model_doc/audio-spectrogram-transformer.md index 4cc07aea75..14669ce0fb 100644 --- a/docs/source/en/model_doc/audio-spectrogram-transformer.md +++ b/docs/source/en/model_doc/audio-spectrogram-transformer.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/deit.md b/docs/source/en/model_doc/deit.md index 0750d4000a..57cfee1f11 100644 --- a/docs/source/en/model_doc/deit.md +++ b/docs/source/en/model_doc/deit.md @@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
PyTorch TensorFlow +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/dinov2.md b/docs/source/en/model_doc/dinov2.md index 5c130dabda..acf7b20600 100644 --- a/docs/source/en/model_doc/dinov2.md +++ b/docs/source/en/model_doc/dinov2.md @@ -16,6 +16,7 @@ specific language governing permissions and limitations under the License. PyTorch Flax +FlashAttention SDPA diff --git a/docs/source/en/model_doc/dinov2_with_registers.md b/docs/source/en/model_doc/dinov2_with_registers.md index 7151dc4535..3b12d314a5 100644 --- a/docs/source/en/model_doc/dinov2_with_registers.md +++ b/docs/source/en/model_doc/dinov2_with_registers.md @@ -11,6 +11,7 @@ specific language governing permissions and limitations under the License.
PyTorch +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/dpt.md b/docs/source/en/model_doc/dpt.md index 7010d03cdc..95e422dee8 100644 --- a/docs/source/en/model_doc/dpt.md +++ b/docs/source/en/model_doc/dpt.md @@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention +SDPA
## Overview diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md index a92fdc83e8..d350114781 100644 --- a/docs/source/en/model_doc/ijepa.md +++ b/docs/source/en/model_doc/ijepa.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/videomae.md b/docs/source/en/model_doc/videomae.md index f115d81694..be048d5b73 100644 --- a/docs/source/en/model_doc/videomae.md +++ b/docs/source/en/model_doc/videomae.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/vit.md b/docs/source/en/model_doc/vit.md index 49c5c0e278..05c724ff7b 100644 --- a/docs/source/en/model_doc/vit.md +++ b/docs/source/en/model_doc/vit.md @@ -21,6 +21,7 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention SDPA diff --git a/docs/source/en/model_doc/vit_mae.md b/docs/source/en/model_doc/vit_mae.md index 6ab5096172..893490cf01 100644 --- a/docs/source/en/model_doc/vit_mae.md +++ b/docs/source/en/model_doc/vit_mae.md @@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
PyTorch TensorFlow +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/vit_msn.md b/docs/source/en/model_doc/vit_msn.md index 53cef45011..a3aadef0e9 100644 --- a/docs/source/en/model_doc/vit_msn.md +++ b/docs/source/en/model_doc/vit_msn.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/vivit.md b/docs/source/en/model_doc/vivit.md index 9c4b8f5f71..a2cba9793e 100644 --- a/docs/source/en/model_doc/vivit.md +++ b/docs/source/en/model_doc/vivit.md @@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
PyTorch +FlashAttention SDPA
diff --git a/docs/source/en/model_doc/yolos.md b/docs/source/en/model_doc/yolos.md index a988d0d507..2a0f5d23fa 100644 --- a/docs/source/en/model_doc/yolos.md +++ b/docs/source/en/model_doc/yolos.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention SDPA
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d916e6aaad..4e27421574 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2098,7 +2098,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if not isinstance(requested_attn_implementation, dict) else requested_attn_implementation.get(key, None) ) - sub_config._attn_implementation_internal = curr_attn_implementation + # For models with backbone sub-config might be not initialized + if sub_config is not None: + sub_config._attn_implementation_internal = curr_attn_implementation if use_flash_attention_2: logger.warning_once( diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 3d72676774..e9e029cf53 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -14,8 +14,7 @@ # limitations under the License. 
"""PyTorch Audio Spectrogram Transformer (AST) model.""" -import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -24,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_audio_spectrogram_transformer import ASTConfig @@ -108,6 +107,37 @@ class ASTPatchEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST class ASTSelfAttention(nn.Module): def __init__(self, config: ASTConfig) -> None: @@ -118,16 +148,18 @@ class ASTSelfAttention(nn.Module): f"heads {config.num_attention_heads}." ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -136,85 +168,37 @@ class ASTSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = 
self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST -class ASTSdpaSelfAttention(ASTSelfAttention): - def __init__(self, config: ASTConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.FloatTensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`ASTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - head_mask=head_mask, - output_attentions=output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST @@ -276,13 +260,6 @@ class ASTAttention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST -class ASTSdpaAttention(ASTAttention): - def __init__(self, config: ASTConfig) -> None: - super().__init__(config) - self.attention = ASTSdpaSelfAttention(config) - - # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST class ASTIntermediate(nn.Module): def __init__(self, config: ASTConfig) -> None: @@ -316,12 +293,6 @@ class ASTOutput(nn.Module): return hidden_states -AST_ATTENTION_CLASSES = { - "eager": ASTAttention, - "sdpa": ASTSdpaAttention, -} - - # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST class ASTLayer(nn.Module): 
"""This corresponds to the Block class in the timm implementation.""" @@ -330,7 +301,7 @@ class ASTLayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = ASTAttention(config) self.intermediate = ASTIntermediate(config) self.output = ASTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -428,6 +399,7 @@ class ASTPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_sdpa = True + _supports_flash_attn_2 = True # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 66a556da81..3c70ae132f 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -15,9 +15,8 @@ """PyTorch DeiT model.""" import collections.abc -import math from dataclasses import dataclass -from typing import Optional, Set, Tuple, Union +from typing import Callable, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -31,7 +30,7 @@ from ...modeling_outputs import ( ImageClassifierOutput, MaskedImageModelingOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -180,6 +179,37 @@ class DeiTPatchEmbeddings(nn.Module): return x +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: 
Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT class DeiTSelfAttention(nn.Module): def __init__(self, config: DeiTConfig) -> None: @@ -190,16 +220,18 @@ class DeiTSelfAttention(nn.Module): f"heads {config.num_attention_heads}." 
) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -208,85 +240,37 @@ class DeiTSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT -class DeiTSdpaSelfAttention(DeiTSelfAttention): - def __init__(self, config: DeiTConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.FloatTensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`DeiTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - head_mask=head_mask, - output_attentions=output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT @@ -348,13 +332,6 @@ class DeiTAttention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT -class DeiTSdpaAttention(DeiTAttention): - def __init__(self, config: DeiTConfig) -> None: - super().__init__(config) - self.attention = DeiTSdpaSelfAttention(config) - - # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT class DeiTIntermediate(nn.Module): def __init__(self, config: DeiTConfig) -> None: @@ -388,12 +365,6 @@ class DeiTOutput(nn.Module): return hidden_states -DEIT_ATTENTION_CLASSES = { - "eager": DeiTAttention, - "sdpa": DeiTSdpaAttention, -} - - # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT class 
DeiTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -402,7 +373,7 @@ class DeiTLayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = DeiTAttention(config) self.intermediate = DeiTIntermediate(config) self.output = DeiTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -501,6 +472,7 @@ class DeiTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DeiTLayer"] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index 98a6ccde8c..17b79134ad 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -240,7 +240,8 @@ class DepthAnythingFeatureFusionStage(nn.Module): return fused_hidden_states -# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything +# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything +# avoiding sdpa and flash_attn_2 support, it's done in the backend class DepthAnythingPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 3ba48b7026..2e11d3a76c 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -15,8 +15,7 
@@ """PyTorch DINOv2 model.""" import collections.abc -import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -30,7 +29,7 @@ from ...modeling_outputs import ( BaseModelOutputWithPooling, ImageClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, @@ -172,6 +171,37 @@ class Dinov2PatchEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 class Dinov2SelfAttention(nn.Module): def __init__(self, config: Dinov2Config) -> None: @@ -182,16 +212,18 @@ class Dinov2SelfAttention(nn.Module): f"heads {config.num_attention_heads}." ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -200,78 +232,37 @@ class Dinov2SelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = 
self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class Dinov2SdpaSelfAttention(Dinov2SelfAttention): - def __init__(self, config: Dinov2Config) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Dinov2Model is using Dinov2SdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 @@ -333,13 +324,6 @@ class Dinov2Attention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2 -class Dinov2SdpaAttention(Dinov2Attention): - def __init__(self, config: Dinov2Config) -> None: - super().__init__(config) - self.attention = Dinov2SdpaSelfAttention(config) - - class Dinov2LayerScale(nn.Module): def __init__(self, config) -> None: super().__init__() @@ -421,12 +405,6 @@ class Dinov2SwiGLUFFN(nn.Module): return self.weights_out(hidden) -DINOV2_ATTENTION_CLASSES = { - "eager": Dinov2Attention, - "sdpa": Dinov2SdpaAttention, -} - - class Dinov2Layer(nn.Module): """This corresponds to the Block class in the original implementation.""" @@ -434,7 +412,7 @@ class 
Dinov2Layer(nn.Module): super().__init__() self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = Dinov2Attention(config) self.layer_scale1 = Dinov2LayerScale(config) self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() @@ -542,6 +520,7 @@ class Dinov2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Dinov2SwiGLUFFN"] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index dae5904b78..c7c48dadb7 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -21,8 +21,7 @@ # limitations under the License. 
import collections.abc -import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -30,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, @@ -185,6 +184,36 @@ class Dinov2WithRegistersEmbeddings(nn.Module): return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Dinov2WithRegistersSelfAttention(nn.Module): def __init__(self, config: Dinov2WithRegistersConfig) -> None: super().__init__() @@ -194,16 +223,18 @@ class Dinov2WithRegistersSelfAttention(nn.Module): f"heads {config.num_attention_heads}." ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -212,78 +243,37 @@ class Dinov2WithRegistersSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = 
self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention): - def __init__(self, config: Dinov2WithRegistersConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs class Dinov2WithRegistersSelfOutput(nn.Module): @@ -343,12 +333,6 @@ class Dinov2WithRegistersAttention(nn.Module): return outputs -class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention): - def __init__(self, config: Dinov2WithRegistersConfig) -> None: - super().__init__(config) - self.attention = Dinov2WithRegistersSdpaSelfAttention(config) - - class Dinov2WithRegistersLayerScale(nn.Module): def __init__(self, config) -> None: super().__init__() @@ -428,12 +412,6 @@ class Dinov2WithRegistersSwiGLUFFN(nn.Module): return self.weights_out(hidden) -DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = { - "eager": Dinov2WithRegistersAttention, - "sdpa": Dinov2WithRegistersSdpaAttention, -} - - class Dinov2WithRegistersLayer(nn.Module): """This corresponds to the Block class in the original implementation.""" @@ -441,7 +419,7 
@@ class Dinov2WithRegistersLayer(nn.Module): super().__init__() self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = Dinov2WithRegistersAttention(config) self.layer_scale1 = Dinov2WithRegistersLayerScale(config) self.drop_path = ( Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() @@ -550,6 +528,7 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/dpt/configuration_dpt.py b/src/transformers/models/dpt/configuration_dpt.py index 516f8f43f0..32015924cf 100644 --- a/src/transformers/models/dpt/configuration_dpt.py +++ b/src/transformers/models/dpt/configuration_dpt.py @@ -282,5 +282,9 @@ class DPTConfig(PretrainedConfig): output["model_type"] = self.__class__.model_type return output + @property + def sub_configs(self): + return {"backbone_config": type(self.backbone_config)} if self.backbone_config is not None else {} + __all__ = ["DPTConfig"] diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index e4d55603e6..66c779294a 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -20,9 +20,8 @@ https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_hea """ import collections.abc -import math from dataclasses import dataclass -from typing import List, Optional, Set, Tuple, Union +from typing import Callable, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -37,7 +36,7 @@ from ...file_utils import ( replace_return_docstrings, ) from 
...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, logging, torch_int from ...utils.backbone_utils import load_backbone @@ -295,8 +294,39 @@ class DPTViTPatchEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT -class DPTViTSelfAttention(nn.Module): +class DPTSelfAttention(nn.Module): def __init__(self, config: DPTConfig) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): @@ -305,16 +335,18 @@ class DPTViTSelfAttention(nn.Module): f"heads {config.num_attention_heads}." 
) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -323,33 +355,33 @@ class DPTViTSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) @@ -378,7 +410,7 @@ class DPTViTSelfOutput(nn.Module): class DPTViTAttention(nn.Module): def __init__(self, config: DPTConfig) -> None: super().__init__() - self.attention = DPTViTSelfAttention(config) + self.attention = DPTSelfAttention(config) self.output = DPTViTSelfOutput(config) self.pruned_heads = set() @@ -809,6 +841,8 @@ class DPTPreTrainedModel(PreTrainedModel): base_model_prefix = "dpt" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 7d4619480c..5c738fe07b 100644 --- 
a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -5,8 +5,7 @@ # modular_ijepa.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import collections.abc -import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -14,7 +13,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, @@ -167,6 +166,7 @@ class IJepaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" @@ -189,6 +189,36 @@ class IJepaPreTrainedModel(PreTrainedModel): ).to(module.position_embeddings.dtype) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. 
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class IJepaSelfAttention(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -198,16 +228,18 @@ class IJepaSelfAttention(nn.Module): f"heads {config.num_attention_heads}." ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -216,84 +248,37 @@ class IJepaSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = 
self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class IJepaSdpaSelfAttention(IJepaSelfAttention): - def __init__(self, config: IJepaConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.FloatTensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`IJepaSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - head_mask=head_mask, - output_attentions=output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs class IJepaSelfOutput(nn.Module): @@ -353,12 +338,6 @@ class IJepaAttention(nn.Module): return outputs -class IJepaSdpaAttention(IJepaAttention): - def __init__(self, config: IJepaConfig) -> None: - super().__init__(config) - self.attention = IJepaSdpaSelfAttention(config) - - class IJepaIntermediate(nn.Module): def __init__(self, config: IJepaConfig) -> None: super().__init__() @@ -390,12 +369,6 @@ class IJepaOutput(nn.Module): return hidden_states -IJEPA_ATTENTION_CLASSES = { - "eager": IJepaAttention, - "sdpa": IJepaSdpaAttention, -} - - class IJepaLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -403,7 +376,7 @@ class IJepaLayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = 
IJEPA_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = IJepaAttention(config) self.intermediate = IJepaIntermediate(config) self.output = IJepaOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -531,7 +504,6 @@ IJEPA_INPUTS_DOCSTRING = r""" _EXPECTED_OUTPUT_SHAPE = [1, 256, 1280] - IJEPA_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index 3b3756dd5c..447347a4ec 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -108,6 +108,7 @@ class IJepaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 0e51cd9886..8174a430b7 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -15,10 +15,9 @@ """PyTorch VideoMAE (masked autoencoder) model.""" import collections.abc -import math from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Set, Tuple, Union +from typing import Callable, Optional, Set, Tuple, Union import numpy as np import torch @@ -28,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput -from ...modeling_utils import 
PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -196,6 +195,37 @@ class VideoMAEPatchEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class VideoMAESelfAttention(nn.Module): def __init__(self, config: VideoMAEConfig) -> None: super().__init__() @@ -204,10 +234,13 @@ class VideoMAESelfAttention(nn.Module): f"The hidden size {config.hidden_size} is not a multiple of the number of attention " f"heads {config.num_attention_heads}." 
) - + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) @@ -220,8 +253,6 @@ class VideoMAESelfAttention(nn.Module): self.q_bias = None self.v_bias = None - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -239,65 +270,33 @@ class VideoMAESelfAttention(nn.Module): value_layer = self.transpose_for_scores(values) query_layer = self.transpose_for_scores(queries) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. 
- attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class VideoMAESdpaSelfAttention(VideoMAESelfAttention): - def __init__(self, config: VideoMAEConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None - keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias) - values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias) - queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias) - - key_layer = self.transpose_for_scores(keys) - value_layer = self.transpose_for_scores(values) - query_layer = self.transpose_for_scores(queries) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + 
scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE @@ -359,13 +358,6 @@ class VideoMAEAttention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE -class VideoMAESdpaAttention(VideoMAEAttention): - def __init__(self, config: VideoMAEConfig) -> None: - super().__init__(config) - self.attention = VideoMAESdpaSelfAttention(config) - - # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE class VideoMAEIntermediate(nn.Module): def __init__(self, config: VideoMAEConfig) -> None: @@ -399,9 +391,6 @@ class VideoMAEOutput(nn.Module): return hidden_states -VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention} - - # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE class VideoMAELayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -410,7 +399,7 @@ class VideoMAELayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = VideoMAEAttention(config) self.intermediate = VideoMAEIntermediate(config) self.output = VideoMAEOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -508,6 +497,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel): main_input_name = 
"pixel_values" supports_gradient_checkpointing = True _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 2fd430c101..2d392ccaf9 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -16,7 +16,7 @@ import collections.abc import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -30,7 +30,7 @@ from ...modeling_outputs import ( ImageClassifierOutput, MaskedImageModelingOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_code_sample_docstrings, @@ -184,6 +184,36 @@ class ViTPatchEmbeddings(nn.Module): return embeddings +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class ViTSelfAttention(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() @@ -193,16 +223,18 @@ class ViTSelfAttention(nn.Module): f"heads {config.num_attention_heads}." ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -211,84 +243,37 @@ class ViTSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to 
get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class ViTSdpaSelfAttention(ViTSelfAttention): - def __init__(self, config: ViTConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.FloatTensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - 
logger.warning_once( - "`ViTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - head_mask=head_mask, - output_attentions=output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs class ViTSelfOutput(nn.Module): @@ -348,12 +333,6 @@ class ViTAttention(nn.Module): return outputs -class ViTSdpaAttention(ViTAttention): - def __init__(self, config: ViTConfig) -> None: - super().__init__(config) - self.attention = ViTSdpaSelfAttention(config) - - class ViTIntermediate(nn.Module): def __init__(self, config: ViTConfig) -> None: super().__init__() @@ -385,12 +364,6 @@ class 
ViTOutput(nn.Module): return hidden_states -VIT_ATTENTION_CLASSES = { - "eager": ViTAttention, - "sdpa": ViTSdpaAttention, -} - - class ViTLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -398,7 +371,7 @@ class ViTLayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = ViTAttention(config) self.intermediate = ViTIntermediate(config) self.output = ViTOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -496,6 +469,7 @@ class ViTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ViTEmbeddings", "ViTLayer"] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 86e71155d9..c002c41ca0 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -15,10 +15,9 @@ """PyTorch ViT MAE (masked autoencoder) model.""" import collections.abc -import math from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Set, Tuple, Union +from typing import Callable, Optional, Set, Tuple, Union import numpy as np import torch @@ -27,7 +26,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -356,6 +355,37 @@ class ViTMAEPatchEmbeddings(nn.Module): return x +# Copied from 
transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE class ViTMAESelfAttention(nn.Module): def __init__(self, config: ViTMAEConfig) -> None: @@ -366,16 +396,18 @@ class ViTMAESelfAttention(nn.Module): f"heads {config.num_attention_heads}." 
) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -384,85 +416,37 @@ class ViTMAESelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE -class ViTMAESdpaSelfAttention(ViTMAESelfAttention): - def __init__(self, config: ViTMAEConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.FloatTensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`ViTMAESdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - head_mask=head_mask, - output_attentions=output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE @@ -524,13 +508,6 @@ class ViTMAEAttention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE -class ViTMAESdpaAttention(ViTMAEAttention): - def __init__(self, config: ViTMAEConfig) -> None: - super().__init__(config) - self.attention = ViTMAESdpaSelfAttention(config) - - # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE class ViTMAEIntermediate(nn.Module): def __init__(self, config: ViTMAEConfig) -> None: @@ -564,12 +541,6 @@ class ViTMAEOutput(nn.Module): return hidden_states -VITMAE_ATTENTION_CLASSES = { - "eager": ViTMAEAttention, - "sdpa": ViTMAESdpaAttention, -} - - # Copied from transformers.models.vit.modeling_vit.ViTLayer with 
ViT->ViTMAE,VIT->VITMAE class ViTMAELayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -578,7 +549,7 @@ class ViTMAELayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = ViTMAEAttention(config) self.intermediate = ViTMAEIntermediate(config) self.output = ViTMAEOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -676,6 +647,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 79021a6b8b..8f25438ef9 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -15,8 +15,7 @@ """PyTorch ViT MSN (masked siamese network) model.""" import collections.abc -import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -25,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_start_docstrings, @@ -173,6 +172,37 @@ class ViTMSNPatchEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + 
module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN class ViTMSNSelfAttention(nn.Module): def __init__(self, config: ViTMSNConfig) -> None: @@ -183,16 +213,18 @@ class ViTMSNSelfAttention(nn.Module): f"heads {config.num_attention_heads}." 
) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -201,85 +233,37 @@ class ViTMSNSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN -class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention): - def __init__(self, config: ViTMSNConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.FloatTensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`ViTMSNSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - head_mask=head_mask, - output_attentions=output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN @@ -341,13 +325,6 @@ class ViTMSNAttention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN -class ViTMSNSdpaAttention(ViTMSNAttention): - def __init__(self, config: ViTMSNConfig) -> None: - super().__init__(config) - self.attention = ViTMSNSdpaSelfAttention(config) - - # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN class ViTMSNIntermediate(nn.Module): def __init__(self, config: ViTMSNConfig) -> None: @@ -381,9 +358,6 @@ class ViTMSNOutput(nn.Module): return hidden_states -VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention} - - # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, 
VIT->VITMSN class ViTMSNLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -392,7 +366,7 @@ class ViTMSNLayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = ViTMSNAttention(config) self.intermediate = ViTMSNIntermediate(config) self.output = ViTMSNOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -491,6 +465,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"] _supports_sdpa = True + _supports_flash_attn_2 = True # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 # when creating pre-training scripts. diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index b4a1acd336..c0d6d7f022 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -20,8 +20,7 @@ This code is the same as the original Vision Transformer (ViT) with 2 modificati """ import collections.abc -import math -from typing import Optional, Set, Tuple, Union +from typing import Callable, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -29,7 +28,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput, BaseModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_start_docstrings, @@ -103,6 +102,37 @@ class VitPoseBackboneEmbeddings(nn.Module): return 
embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone class VitPoseBackboneSelfAttention(nn.Module): def __init__(self, config: VitPoseBackboneConfig) -> None: @@ -113,16 +143,18 @@ class VitPoseBackboneSelfAttention(nn.Module): f"heads {config.num_attention_heads}." 
) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -131,33 +163,33 @@ class VitPoseBackboneSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) @@ -392,6 +424,8 @@ class VitPoseBackbonePreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True _no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"] + _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 4ef0f29bc8..bd6ce5234f 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ 
b/src/transformers/models/vivit/modeling_vivit.py @@ -14,8 +14,7 @@ # limitations under the License. """PyTorch ViViT model.""" -import math -from typing import Optional, Set, Tuple, Union +from typing import Callable, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -24,7 +23,7 @@ from torch.nn import CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( add_start_docstrings, @@ -166,6 +165,37 @@ class VivitEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit class VivitSelfAttention(nn.Module): def __init__(self, config: VivitConfig) -> None: @@ -176,16 +206,18 @@ class VivitSelfAttention(nn.Module): f"heads {config.num_attention_heads}." ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -194,82 +226,37 @@ class VivitSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = 
self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit -class VivitSdpaSelfAttention(VivitSelfAttention): - def __init__(self, config: VivitConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - logger.warning_once( - "VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support" - " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying" - " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be" - ' removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states, - head_mask, - output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit @@ -331,13 +318,6 @@ class VivitAttention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit -class VivitSdpaAttention(VivitAttention): - def __init__(self, config: VivitConfig) -> None: - super().__init__(config) - self.attention = VivitSdpaSelfAttention(config) - - class VivitIntermediate(nn.Module): def __init__(self, config): super().__init__() @@ -372,12 +352,6 @@ class VivitOutput(nn.Module): return hidden_states -VIVIT_ATTENTION_CLASSES = { - "eager": VivitAttention, - "sdpa": VivitSdpaAttention, -} - - class VivitLayer(nn.Module): """This corresponds to the EncoderBlock class in the scenic/vivit implementation.""" @@ -385,7 +359,7 @@ class VivitLayer(nn.Module): super().__init__() self.chunk_size_feed_forward = 
config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = VivitAttention(config) self.intermediate = VivitIntermediate(config) self.output = VivitOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -495,6 +469,7 @@ class VivitPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 5801e0bca2..06edd9b4e5 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -15,9 +15,8 @@ """PyTorch YOLOS model.""" import collections.abc -import math from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -25,7 +24,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -231,6 +230,37 @@ class YolosPatchEmbeddings(nn.Module): return embeddings +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Yolos class YolosSelfAttention(nn.Module): def __init__(self, config: YolosConfig) -> None: @@ -241,16 +271,18 @@ class YolosSelfAttention(nn.Module): f"heads {config.num_attention_heads}." ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) @@ -259,85 +291,37 @@ class YolosSelfAttention(nn.Module): def forward( self, hidden_states, head_mask: Optional[torch.Tensor] = None, 
output_attentions: bool = False ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos -class YolosSdpaSelfAttention(YolosSelfAttention): - def __init__(self, config: YolosConfig) -> None: - super().__init__(config) - self.attention_probs_dropout_prob = config.attention_probs_dropout_prob - - def forward( - self, - hidden_states: torch.FloatTensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`YolosSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - head_mask=head_mask, - output_attentions=output_attentions, - ) - - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - context_layer = torch.nn.functional.scaled_dot_product_attention( + context_layer, attention_probs = attention_interface( + self, query_layer, key_layer, value_layer, head_mask, - self.attention_probs_dropout_prob if self.training else 0.0, - is_causal=False, - scale=None, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, ) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) + context_layer = context_layer.reshape(new_context_layer_shape) - return context_layer, None + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos @@ -399,13 +383,6 @@ class YolosAttention(nn.Module): return outputs -# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos -class YolosSdpaAttention(YolosAttention): - def __init__(self, config: YolosConfig) -> None: - super().__init__(config) - self.attention = YolosSdpaSelfAttention(config) - - # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos class YolosIntermediate(nn.Module): def __init__(self, config: YolosConfig) -> None: @@ -439,9 +416,6 @@ class YolosOutput(nn.Module): return hidden_states -YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention} - - # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS class 
YolosLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" @@ -450,7 +424,7 @@ class YolosLayer(nn.Module): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config) + self.attention = YolosAttention(config) self.intermediate = YolosIntermediate(config) self.output = YolosOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -575,6 +549,7 @@ class YolosPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] _supports_sdpa = True + _supports_flash_attn_2 = True def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index 81eca0e3bf..57cad7328f 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -1219,7 +1219,8 @@ class ZoeDepthMetricDepthEstimationHead(nn.Module): return out, None -# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth +# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth +# avoiding sdpa and flash_attn_2 support, it's done in the backend class ZoeDepthPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index fc652c2484..da40466d48 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -255,6 +255,10 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): def 
test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip(reason="Inductor error for dynamic shape") + def test_sdpa_can_compile_dynamic(self): + pass + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py index 2d56bbd551..4b1abab206 100644 --- a/tests/models/videomae/test_modeling_videomae.py +++ b/tests/models/videomae/test_modeling_videomae.py @@ -15,14 +15,24 @@ """Testing suite for the PyTorch VideoMAE model.""" import copy +import tempfile import unittest import numpy as np from huggingface_hub import hf_hub_download +from pytest import mark from transformers import VideoMAEConfig from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + is_flaky, + require_flash_attn, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -338,6 +348,59 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase check_hidden_states_output(inputs_dict, config, model_class) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_2_inference_equivalence(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + inputs_dict["pixel_values"] = 
inputs_dict["pixel_values"].to(torch.bfloat16) + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + outputs = model(**inputs_dict, output_hidden_states=True) + outputs_fa = model_fa(**inputs_dict, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(**inputs_dict) + + @unittest.skip("Not applicable for VideoMAE") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + # We will verify our results on a video of eating spaghetti # Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227] diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index d2f0424889..177ddc269d 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -19,9 +19,18 @@ import tempfile import unittest import numpy as np +from pytest import mark from transformers import ViTMAEConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + is_flaky, + require_flash_attn, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common 
import ConfigTester @@ -269,6 +278,63 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): model = ViTMAEModel.from_pretrained(model_name) self.assertIsNotNone(model) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_2_inference_equivalence(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16) + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # ForPretraining model has random `noise` -> need to set seed + # to make the test deterministic + torch.manual_seed(12345) + outputs = model(**inputs_dict, output_hidden_states=True) + torch.manual_seed(12345) + outputs_fa = model_fa(**inputs_dict, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(**inputs_dict) + + @unittest.skip("Not applicable for VideoMAE") + def 
test_flash_attn_2_inference_equivalence_right_padding(self): + pass + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index d37dd92c71..c16a888885 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -130,7 +130,7 @@ class ConfigTester: general_config_dict = config.to_dict() # Iterate over all sub_configs if there are any and load them with their own classes - sub_configs = self.config_class.sub_configs + sub_configs = general_config_loaded.sub_configs for sub_config_key, sub_class in sub_configs.items(): if sub_class.__name__ == "AutoConfig": sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 817c5208d0..58c09d4178 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -315,8 +315,6 @@ class ModelTesterMixin: return inputs_dict def test_save_load(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - def check_save_load(out1, out2): # make sure we don't have nans out_2 = out2.cpu().numpy() @@ -330,6 +328,7 @@ class ModelTesterMixin: self.assertLessEqual(max_diff, 1e-5) for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) model.to(torch_device) model.eval() @@ -508,16 +507,16 @@ class ModelTesterMixin: @is_flaky(description="low likelihood of failure, reason not yet discovered") def test_save_load_fast_init_from_base(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if config.__class__ not in MODEL_MAPPING: - self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING") - - base_class = MODEL_MAPPING[config.__class__] - - if isinstance(base_class, tuple): - base_class = base_class[0] - for model_class 
in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if config.__class__ not in MODEL_MAPPING: + self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING") + + base_class = MODEL_MAPPING[config.__class__] + + if isinstance(base_class, tuple): + base_class = base_class[0] + if model_class == base_class: continue @@ -2228,9 +2227,9 @@ class ModelTesterMixin: def test_correct_missing_keys(self): if not self.test_missing_keys: self.skipTest(reason="test_missing_keys is set to `False`") - config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) base_model_prefix = model.base_model_prefix @@ -2287,8 +2286,8 @@ class ModelTesterMixin: @require_safetensors def test_can_use_safetensors(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() model_tied = model_class(config) with tempfile.TemporaryDirectory() as d: try: @@ -2323,9 +2322,9 @@ class ModelTesterMixin: ) def test_load_save_without_tied_weights(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - config.tie_word_embeddings = False for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config.tie_word_embeddings = False model = model_class(config) with tempfile.TemporaryDirectory() as d: model.save_pretrained(d) @@ -2373,8 +2372,8 @@ class ModelTesterMixin: ) def test_model_weights_reload_no_missing_tied_weights(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) with tempfile.TemporaryDirectory() as tmp_dir: 
model.save_pretrained(tmp_dir)