Refactor Attention implementation for ViT-based models (#36545)

* Refactor vit attention

* Refactor ViT-based models

* 🚨🚨🚨 Fix prefix for DPT

* Update params order

* trigger tests

* Fix Dinov2 attention

* Fix DPT attention impl propagation for backbone config

* Common test fix: config is modified in place - avoid it

* view->reshape

* Fixup

* Fixup

* Enable IJepa FA2

* Add FA2 in corresponding model docs
This commit is contained in:
Pavel Iakubovskii
2025-03-20 15:15:01 +00:00
committed by GitHub
parent 730d2a52e7
commit 66291778dd
35 changed files with 932 additions and 975 deletions

View File

@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -16,6 +16,7 @@ specific language governing permissions and limitations under the License.
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuh
YkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -11,6 +11,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
## Overview

View File

@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -21,6 +21,7 @@ rendered properly in your Markdown viewer.
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuh
YkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

View File

@@ -2098,6 +2098,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if not isinstance(requested_attn_implementation, dict)
else requested_attn_implementation.get(key, None)
)
# For models with a backbone, the sub-config might not be initialized
if sub_config is not None:
sub_config._attn_implementation_internal = curr_attn_implementation
if use_flash_attention_2:

View File

@@ -14,8 +14,7 @@
# limitations under the License.
"""PyTorch Audio Spectrogram Transformer (AST) model."""
import math
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -24,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_audio_spectrogram_transformer import ASTConfig
@@ -108,6 +107,37 @@ class ASTPatchEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute attention with explicit matmuls (the "eager" implementation).

    Args:
        module: Owning module; only its `training` flag is read, to decide
            whether dropout is active.
        query: Query tensor.
        key: Key tensor; its last two dimensions are swapped before the product.
        value: Value tensor.
            NOTE(review): q/k/v are presumably (batch, heads, seq, head_dim) - confirm against the caller.
        attention_mask: Optional mask multiplied into the attention probabilities
            after dropout (used to mask heads). `None` skips this step.
        scaling: Factor the raw query-key scores are multiplied by.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted so callers can pass extra arguments; unused here.

    Returns:
        A tuple `(attn_output, attn_weights)`: the weighted values with
        dimensions 1 and 2 swapped and made contiguous, and the attention
        probabilities after dropout and masking.
    """
    # Raw attention scores: scaled dot product between "query" and "key".
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Turn scores into probabilities; the softmax runs in float32 and the
    # result is cast back to the dtype of the query.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)

    # Dropout on the probabilities drops entire tokens to attend to, which
    # might seem unusual but follows the original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Mask heads if requested.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
class ASTSelfAttention(nn.Module):
def __init__(self, config: ASTConfig) -> None:
@@ -118,16 +148,18 @@ class ASTSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -136,85 +168,37 @@ class ASTSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
class ASTSdpaSelfAttention(ASTSelfAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`ASTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
@@ -276,13 +260,6 @@ class ASTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
class ASTSdpaAttention(ASTAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention = ASTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
class ASTIntermediate(nn.Module):
def __init__(self, config: ASTConfig) -> None:
@@ -316,12 +293,6 @@ class ASTOutput(nn.Module):
return hidden_states
AST_ATTENTION_CLASSES = {
"eager": ASTAttention,
"sdpa": ASTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
class ASTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -330,7 +301,7 @@ class ASTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = ASTAttention(config)
self.intermediate = ASTIntermediate(config)
self.output = ASTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -428,6 +399,7 @@ class ASTPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
# Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:

View File

@@ -15,9 +15,8 @@
"""PyTorch DeiT model."""
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Set, Tuple, Union
from typing import Callable, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -31,7 +30,7 @@ from ...modeling_outputs import (
ImageClassifierOutput,
MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -180,6 +179,37 @@ class DeiTPatchEmbeddings(nn.Module):
return x
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute attention with explicit matmuls (the "eager" implementation).

    Args:
        module: Owning module; only its `training` flag is read, to decide
            whether dropout is active.
        query: Query tensor.
        key: Key tensor; its last two dimensions are swapped before the product.
        value: Value tensor.
            NOTE(review): q/k/v are presumably (batch, heads, seq, head_dim) - confirm against the caller.
        attention_mask: Optional mask multiplied into the attention probabilities
            after dropout (used to mask heads). `None` skips this step.
        scaling: Factor the raw query-key scores are multiplied by.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted so callers can pass extra arguments; unused here.

    Returns:
        A tuple `(attn_output, attn_weights)`: the weighted values with
        dimensions 1 and 2 swapped and made contiguous, and the attention
        probabilities after dropout and masking.
    """
    # Raw attention scores: scaled dot product between "query" and "key".
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Turn scores into probabilities; the softmax runs in float32 and the
    # result is cast back to the dtype of the query.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)

    # Dropout on the probabilities drops entire tokens to attend to, which
    # might seem unusual but follows the original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Mask heads if requested.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
class DeiTSelfAttention(nn.Module):
def __init__(self, config: DeiTConfig) -> None:
@@ -190,16 +220,18 @@ class DeiTSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -208,85 +240,37 @@ class DeiTSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT
class DeiTSdpaSelfAttention(DeiTSelfAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`DeiTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
@@ -348,13 +332,6 @@ class DeiTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT
class DeiTSdpaAttention(DeiTAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention = DeiTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
class DeiTIntermediate(nn.Module):
def __init__(self, config: DeiTConfig) -> None:
@@ -388,12 +365,6 @@ class DeiTOutput(nn.Module):
return hidden_states
DEIT_ATTENTION_CLASSES = {
"eager": DeiTAttention,
"sdpa": DeiTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
class DeiTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -402,7 +373,7 @@ class DeiTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = DeiTAttention(config)
self.intermediate = DeiTIntermediate(config)
self.output = DeiTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -501,6 +472,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["DeiTLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -240,7 +240,8 @@ class DepthAnythingFeatureFusionStage(nn.Module):
return fused_hidden_states
# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything
# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything
# avoiding sdpa and flash_attn_2 support, it's done in the backend
class DepthAnythingPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained

View File

@@ -15,8 +15,7 @@
"""PyTorch DINOv2 model."""
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -30,7 +29,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
@@ -172,6 +171,37 @@ class Dinov2PatchEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute attention with explicit matmuls (the "eager" implementation).

    Args:
        module: Owning module; only its `training` flag is read, to decide
            whether dropout is active.
        query: Query tensor.
        key: Key tensor; its last two dimensions are swapped before the product.
        value: Value tensor.
            NOTE(review): q/k/v are presumably (batch, heads, seq, head_dim) - confirm against the caller.
        attention_mask: Optional mask multiplied into the attention probabilities
            after dropout (used to mask heads). `None` skips this step.
        scaling: Factor the raw query-key scores are multiplied by.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted so callers can pass extra arguments; unused here.

    Returns:
        A tuple `(attn_output, attn_weights)`: the weighted values with
        dimensions 1 and 2 swapped and made contiguous, and the attention
        probabilities after dropout and masking.
    """
    # Raw attention scores: scaled dot product between "query" and "key".
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Turn scores into probabilities; the softmax runs in float32 and the
    # result is cast back to the dtype of the query.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)

    # Dropout on the probabilities drops entire tokens to attend to, which
    # might seem unusual but follows the original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Mask heads if requested.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
class Dinov2SelfAttention(nn.Module):
def __init__(self, config: Dinov2Config) -> None:
@@ -182,16 +212,18 @@ class Dinov2SelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -200,78 +232,37 @@ class Dinov2SelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
def __init__(self, config: Dinov2Config) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"Dinov2Model is using Dinov2SdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
@@ -333,13 +324,6 @@ class Dinov2Attention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2
class Dinov2SdpaAttention(Dinov2Attention):
def __init__(self, config: Dinov2Config) -> None:
super().__init__(config)
self.attention = Dinov2SdpaSelfAttention(config)
class Dinov2LayerScale(nn.Module):
def __init__(self, config) -> None:
super().__init__()
@@ -421,12 +405,6 @@ class Dinov2SwiGLUFFN(nn.Module):
return self.weights_out(hidden)
DINOV2_ATTENTION_CLASSES = {
"eager": Dinov2Attention,
"sdpa": Dinov2SdpaAttention,
}
class Dinov2Layer(nn.Module):
"""This corresponds to the Block class in the original implementation."""
@@ -434,7 +412,7 @@ class Dinov2Layer(nn.Module):
super().__init__()
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = Dinov2Attention(config)
self.layer_scale1 = Dinov2LayerScale(config)
self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -542,6 +520,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Dinov2SwiGLUFFN"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -21,8 +21,7 @@
# limitations under the License.
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -30,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
@@ -185,6 +184,36 @@ class Dinov2WithRegistersEmbeddings(nn.Module):
return embeddings
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute scaled dot-product attention with explicit matmul/softmax ops.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query: Query tensor; multiplied with `key` transposed on its last two dims.
            Assumed to be laid out as (batch, heads, seq_len, head_dim) -- TODO confirm
            against the caller.
        key: Key tensor, same layout assumption as `query`.
        value: Value tensor, combined with the attention weights.
        attention_mask: Optional tensor multiplied element-wise into the attention
            weights after dropout (a multiplicative mask, not an additive bias).
            `None` disables masking.
        scaling: Factor applied to the raw `query @ key^T` scores before softmax.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted for interface compatibility and ignored here.

    Returns:
        A tuple `(attn_output, attn_weights)`: `attn_output` is
        `attn_weights @ value` with dims 1 and 2 swapped and made contiguous;
        `attn_weights` are the (dropped-out, optionally masked) probabilities.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # Softmax is evaluated in float32 and cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class Dinov2WithRegistersSelfAttention(nn.Module):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__()
@@ -194,16 +223,18 @@ class Dinov2WithRegistersSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -212,78 +243,37 @@ class Dinov2WithRegistersSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class Dinov2WithRegistersSelfOutput(nn.Module):
@@ -343,12 +333,6 @@ class Dinov2WithRegistersAttention(nn.Module):
return outputs
class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention):
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
super().__init__(config)
self.attention = Dinov2WithRegistersSdpaSelfAttention(config)
class Dinov2WithRegistersLayerScale(nn.Module):
def __init__(self, config) -> None:
super().__init__()
@@ -428,12 +412,6 @@ class Dinov2WithRegistersSwiGLUFFN(nn.Module):
return self.weights_out(hidden)
DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = {
"eager": Dinov2WithRegistersAttention,
"sdpa": Dinov2WithRegistersSdpaAttention,
}
class Dinov2WithRegistersLayer(nn.Module):
"""This corresponds to the Block class in the original implementation."""
@@ -441,7 +419,7 @@ class Dinov2WithRegistersLayer(nn.Module):
super().__init__()
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = Dinov2WithRegistersAttention(config)
self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
self.drop_path = (
Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -550,6 +528,7 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -282,5 +282,9 @@ class DPTConfig(PretrainedConfig):
output["model_type"] = self.__class__.model_type
return output
@property
def sub_configs(self):
    """Map nested-config attribute names to their config classes.

    Returns:
        `{"backbone_config": type(self.backbone_config)}` when a backbone config
        is set, otherwise an empty dict.
    """
    # No backbone configured: there is no nested config to report.
    if self.backbone_config is None:
        return {}
    return {"backbone_config": type(self.backbone_config)}
__all__ = ["DPTConfig"]

View File

@@ -20,9 +20,8 @@ https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_hea
"""
import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple, Union
from typing import Callable, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -37,7 +36,7 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging, torch_int
from ...utils.backbone_utils import load_backbone
@@ -295,8 +294,39 @@ class DPTViTPatchEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute scaled dot-product attention with explicit matmul/softmax ops.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query: Query tensor; multiplied with `key` transposed on its last two dims.
            Assumed to be laid out as (batch, heads, seq_len, head_dim) -- TODO confirm
            against the caller.
        key: Key tensor, same layout assumption as `query`.
        value: Value tensor, combined with the attention weights.
        attention_mask: Optional tensor multiplied element-wise into the attention
            weights after dropout (a multiplicative mask, not an additive bias).
            `None` disables masking.
        scaling: Factor applied to the raw `query @ key^T` scores before softmax.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted for interface compatibility and ignored here.

    Returns:
        A tuple `(attn_output, attn_weights)`: `attn_output` is
        `attn_weights @ value` with dims 1 and 2 swapped and made contiguous;
        `attn_weights` are the (dropped-out, optionally masked) probabilities.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # Softmax is evaluated in float32 and cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT
class DPTViTSelfAttention(nn.Module):
class DPTSelfAttention(nn.Module):
def __init__(self, config: DPTConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
@@ -305,16 +335,18 @@ class DPTViTSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -323,33 +355,33 @@ class DPTViTSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -378,7 +410,7 @@ class DPTViTSelfOutput(nn.Module):
class DPTViTAttention(nn.Module):
def __init__(self, config: DPTConfig) -> None:
super().__init__()
self.attention = DPTViTSelfAttention(config)
self.attention = DPTSelfAttention(config)
self.output = DPTViTSelfOutput(config)
self.pruned_heads = set()
@@ -809,6 +841,8 @@ class DPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "dpt"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""

View File

@@ -5,8 +5,7 @@
# modular_ijepa.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
@@ -14,7 +13,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
@@ -167,6 +166,7 @@ class IJepaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
@@ -189,6 +189,36 @@ class IJepaPreTrainedModel(PreTrainedModel):
).to(module.position_embeddings.dtype)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute scaled dot-product attention with explicit matmul/softmax ops.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query: Query tensor; multiplied with `key` transposed on its last two dims.
            Assumed to be laid out as (batch, heads, seq_len, head_dim) -- TODO confirm
            against the caller.
        key: Key tensor, same layout assumption as `query`.
        value: Value tensor, combined with the attention weights.
        attention_mask: Optional tensor multiplied element-wise into the attention
            weights after dropout (a multiplicative mask, not an additive bias).
            `None` disables masking.
        scaling: Factor applied to the raw `query @ key^T` scores before softmax.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted for interface compatibility and ignored here.

    Returns:
        A tuple `(attn_output, attn_weights)`: `attn_output` is
        `attn_weights @ value` with dims 1 and 2 swapped and made contiguous;
        `attn_weights` are the (dropped-out, optionally masked) probabilities.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # Softmax is evaluated in float32 and cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class IJepaSelfAttention(nn.Module):
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
@@ -198,16 +228,18 @@ class IJepaSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -216,84 +248,37 @@ class IJepaSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class IJepaSdpaSelfAttention(IJepaSelfAttention):
def __init__(self, config: IJepaConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`IJepaSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class IJepaSelfOutput(nn.Module):
@@ -353,12 +338,6 @@ class IJepaAttention(nn.Module):
return outputs
class IJepaSdpaAttention(IJepaAttention):
def __init__(self, config: IJepaConfig) -> None:
super().__init__(config)
self.attention = IJepaSdpaSelfAttention(config)
class IJepaIntermediate(nn.Module):
def __init__(self, config: IJepaConfig) -> None:
super().__init__()
@@ -390,12 +369,6 @@ class IJepaOutput(nn.Module):
return hidden_states
IJEPA_ATTENTION_CLASSES = {
"eager": IJepaAttention,
"sdpa": IJepaSdpaAttention,
}
class IJepaLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -403,7 +376,7 @@ class IJepaLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = IJepaAttention(config)
self.intermediate = IJepaIntermediate(config)
self.output = IJepaOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -531,7 +504,6 @@ IJEPA_INPUTS_DOCSTRING = r"""
_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]
IJEPA_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and

View File

@@ -108,6 +108,7 @@ class IJepaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -15,10 +15,9 @@
"""PyTorch VideoMAE (masked autoencoder) model."""
import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Set, Tuple, Union
from typing import Callable, Optional, Set, Tuple, Union
import numpy as np
import torch
@@ -28,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -196,6 +195,37 @@ class VideoMAEPatchEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute scaled dot-product attention with explicit matmul/softmax ops.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query: Query tensor; multiplied with `key` transposed on its last two dims.
            Assumed to be laid out as (batch, heads, seq_len, head_dim) -- TODO confirm
            against the caller.
        key: Key tensor, same layout assumption as `query`.
        value: Value tensor, combined with the attention weights.
        attention_mask: Optional tensor multiplied element-wise into the attention
            weights after dropout (a multiplicative mask, not an additive bias).
            `None` disables masking.
        scaling: Factor applied to the raw `query @ key^T` scores before softmax.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted for interface compatibility and ignored here.

    Returns:
        A tuple `(attn_output, attn_weights)`: `attn_output` is
        `attn_weights @ value` with dims 1 and 2 swapped and made contiguous;
        `attn_weights` are the (dropped-out, optionally masked) probabilities.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # Softmax is evaluated in float32 and cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class VideoMAESelfAttention(nn.Module):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__()
@@ -204,10 +234,13 @@ class VideoMAESelfAttention(nn.Module):
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
@@ -220,8 +253,6 @@ class VideoMAESelfAttention(nn.Module):
self.q_bias = None
self.v_bias = None
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -239,65 +270,33 @@ class VideoMAESelfAttention(nn.Module):
value_layer = self.transpose_for_scores(values)
query_layer = self.transpose_for_scores(queries)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class VideoMAESdpaSelfAttention(VideoMAESelfAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
key_layer = self.transpose_for_scores(keys)
value_layer = self.transpose_for_scores(values)
query_layer = self.transpose_for_scores(queries)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
@@ -359,13 +358,6 @@ class VideoMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE
class VideoMAESdpaAttention(VideoMAEAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention = VideoMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
class VideoMAEIntermediate(nn.Module):
def __init__(self, config: VideoMAEConfig) -> None:
@@ -399,9 +391,6 @@ class VideoMAEOutput(nn.Module):
return hidden_states
VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
class VideoMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -410,7 +399,7 @@ class VideoMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = VideoMAEAttention(config)
self.intermediate = VideoMAEIntermediate(config)
self.output = VideoMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -508,6 +497,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""

View File

@@ -16,7 +16,7 @@
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -30,7 +30,7 @@ from ...modeling_outputs import (
ImageClassifierOutput,
MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
@@ -184,6 +184,36 @@ class ViTPatchEmbeddings(nn.Module):
return embeddings
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Compute scaled dot-product attention with explicit matmul / softmax ops.

    Args:
        module (`nn.Module`):
            The calling attention module. Only its `training` flag is read, to decide whether dropout is applied.
        query (`torch.Tensor`):
            Query states. NOTE(review): presumably `(batch, num_heads, seq_len, head_dim)` - confirm against the
            caller; this function only relies on the last two dims for the matmul and on dims 1 and 2 for the
            final transpose.
        key (`torch.Tensor`):
            Key states; its last two dimensions are swapped before being multiplied with `query`.
        value (`torch.Tensor`):
            Value states, multiplied by the attention probabilities.
        attention_mask (`torch.Tensor`, *optional*):
            Multiplicative mask applied to the attention probabilities (after softmax and dropout). Skipped when
            `None`.
        scaling (`float`):
            Factor the raw `query @ key^T` scores are multiplied by before the softmax.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to the attention probabilities while `module.training` is `True`.
        **kwargs:
            Accepted and ignored, so that arguments meant for other attention implementations can be passed.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: the attention output with dimensions 1 and 2 swapped (made
        contiguous), and the attention probabilities that were used to compute it.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # The softmax is computed in float32 and the result is cast back to the dtype of `query`.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class ViTSelfAttention(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
@@ -193,16 +223,18 @@ class ViTSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -211,84 +243,37 @@ class ViTSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class ViTSdpaSelfAttention(ViTSelfAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`ViTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class ViTSelfOutput(nn.Module):
@@ -348,12 +333,6 @@ class ViTAttention(nn.Module):
return outputs
class ViTSdpaAttention(ViTAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention = ViTSdpaSelfAttention(config)
class ViTIntermediate(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
@@ -385,12 +364,6 @@ class ViTOutput(nn.Module):
return hidden_states
VIT_ATTENTION_CLASSES = {
"eager": ViTAttention,
"sdpa": ViTSdpaAttention,
}
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -398,7 +371,7 @@ class ViTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = ViTAttention(config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -496,6 +469,7 @@ class ViTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -15,10 +15,9 @@
"""PyTorch ViT MAE (masked autoencoder) model."""
import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Set, Tuple, Union
from typing import Callable, Optional, Set, Tuple, Union
import numpy as np
import torch
@@ -27,7 +26,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -356,6 +355,37 @@ class ViTMAEPatchEmbeddings(nn.Module):
return x
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Compute scaled dot-product attention with explicit matmul / softmax ops.

    Args:
        module (`nn.Module`):
            The calling attention module. Only its `training` flag is read, to decide whether dropout is applied.
        query (`torch.Tensor`):
            Query states. NOTE(review): presumably `(batch, num_heads, seq_len, head_dim)` - confirm against the
            caller; this function only relies on the last two dims for the matmul and on dims 1 and 2 for the
            final transpose.
        key (`torch.Tensor`):
            Key states; its last two dimensions are swapped before being multiplied with `query`.
        value (`torch.Tensor`):
            Value states, multiplied by the attention probabilities.
        attention_mask (`torch.Tensor`, *optional*):
            Multiplicative mask applied to the attention probabilities (after softmax and dropout). Skipped when
            `None`.
        scaling (`float`):
            Factor the raw `query @ key^T` scores are multiplied by before the softmax.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to the attention probabilities while `module.training` is `True`.
        **kwargs:
            Accepted and ignored, so that arguments meant for other attention implementations can be passed.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: the attention output with dimensions 1 and 2 swapped (made
        contiguous), and the attention probabilities that were used to compute it.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # The softmax is computed in float32 and the result is cast back to the dtype of `query`.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
class ViTMAESelfAttention(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
@@ -366,16 +396,18 @@ class ViTMAESelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -384,85 +416,37 @@ class ViTMAESelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE
class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`ViTMAESdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
@@ -524,13 +508,6 @@ class ViTMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE
class ViTMAESdpaAttention(ViTMAEAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention = ViTMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
@@ -564,12 +541,6 @@ class ViTMAEOutput(nn.Module):
return hidden_states
VITMAE_ATTENTION_CLASSES = {
"eager": ViTMAEAttention,
"sdpa": ViTMAESdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -578,7 +549,7 @@ class ViTMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = ViTMAEAttention(config)
self.intermediate = ViTMAEIntermediate(config)
self.output = ViTMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -676,6 +647,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""

View File

@@ -15,8 +15,7 @@
"""PyTorch ViT MSN (masked siamese network) model."""
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -25,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_start_docstrings,
@@ -173,6 +172,37 @@ class ViTMSNPatchEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Compute scaled dot-product attention with explicit matmul / softmax ops.

    Args:
        module (`nn.Module`):
            The calling attention module. Only its `training` flag is read, to decide whether dropout is applied.
        query (`torch.Tensor`):
            Query states. NOTE(review): presumably `(batch, num_heads, seq_len, head_dim)` - confirm against the
            caller; this function only relies on the last two dims for the matmul and on dims 1 and 2 for the
            final transpose.
        key (`torch.Tensor`):
            Key states; its last two dimensions are swapped before being multiplied with `query`.
        value (`torch.Tensor`):
            Value states, multiplied by the attention probabilities.
        attention_mask (`torch.Tensor`, *optional*):
            Multiplicative mask applied to the attention probabilities (after softmax and dropout). Skipped when
            `None`.
        scaling (`float`):
            Factor the raw `query @ key^T` scores are multiplied by before the softmax.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to the attention probabilities while `module.training` is `True`.
        **kwargs:
            Accepted and ignored, so that arguments meant for other attention implementations can be passed.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: the attention output with dimensions 1 and 2 swapped (made
        contiguous), and the attention probabilities that were used to compute it.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # The softmax is computed in float32 and the result is cast back to the dtype of `query`.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the trailing two dims with a reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN
class ViTMSNSelfAttention(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
@@ -183,16 +213,18 @@ class ViTMSNSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -201,85 +233,37 @@ class ViTMSNSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN
class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`ViTMSNSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
@@ -341,13 +325,6 @@ class ViTMSNAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN
class ViTMSNSdpaAttention(ViTMSNAttention):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__(config)
self.attention = ViTMSNSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
class ViTMSNIntermediate(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
@@ -381,9 +358,6 @@ class ViTMSNOutput(nn.Module):
return hidden_states
VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN
class ViTMSNLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -392,7 +366,7 @@ class ViTMSNLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = ViTMSNAttention(config)
self.intermediate = ViTMSNIntermediate(config)
self.output = ViTMSNOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -491,6 +465,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
_supports_sdpa = True
_supports_flash_attn_2 = True
# todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
# when creating pre-training scripts.

View File

@@ -20,8 +20,7 @@ This code is the same as the original Vision Transformer (ViT) with 2 modificati
"""
import collections.abc
import math
from typing import Optional, Set, Tuple, Union
from typing import Callable, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -29,7 +28,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput, BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_start_docstrings,
@@ -103,6 +102,37 @@ class VitPoseBackboneEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute attention with explicit matmul / softmax / dropout operations.

    Args:
        module: Module whose `training` flag controls whether dropout is applied.
        query: Query tensor; multiplied with `key` transposed over its last two dimensions.
        key: Key tensor.
        value: Value tensor, weighted by the attention probabilities.
        attention_mask: Optional tensor multiplied element-wise into the attention
            probabilities after dropout. Skipped when `None`.
        scaling: Factor applied to the raw scores before the softmax.
        dropout: Dropout probability for the attention probabilities.
        **kwargs: Accepted so callers can pass extra options (e.g. `is_causal`); unused here.

    Returns:
        A tuple `(attn_output, attn_weights)`: the weighted values with dimensions 1 and 2
        swapped and made contiguous, and the probabilities after dropout and masking.
    """
    # Raw scores: scaled dot product between queries and keys.
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Softmax is evaluated in float32 and the result is cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

    # Dropout on the probabilities drops whole tokens to attend to, as in the
    # original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Multiplicative (head) mask, when provided.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
class VitPoseBackboneSelfAttention(nn.Module):
def __init__(self, config: VitPoseBackboneConfig) -> None:
@@ -113,16 +143,18 @@ class VitPoseBackboneSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -131,33 +163,33 @@ class VitPoseBackboneSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -392,6 +424,8 @@ class VitPoseBackbonePreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]) -> None:
"""Initialize the weights"""

View File

@@ -14,8 +14,7 @@
# limitations under the License.
"""PyTorch ViViT model."""
import math
from typing import Optional, Set, Tuple, Union
from typing import Callable, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -24,7 +23,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_start_docstrings,
@@ -166,6 +165,37 @@ class VivitEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute attention with explicit matmul / softmax / dropout operations.

    Args:
        module: Module whose `training` flag controls whether dropout is applied.
        query: Query tensor; multiplied with `key` transposed over its last two dimensions.
        key: Key tensor.
        value: Value tensor, weighted by the attention probabilities.
        attention_mask: Optional tensor multiplied element-wise into the attention
            probabilities after dropout. Skipped when `None`.
        scaling: Factor applied to the raw scores before the softmax.
        dropout: Dropout probability for the attention probabilities.
        **kwargs: Accepted so callers can pass extra options (e.g. `is_causal`); unused here.

    Returns:
        A tuple `(attn_output, attn_weights)`: the weighted values with dimensions 1 and 2
        swapped and made contiguous, and the probabilities after dropout and masking.
    """
    # Raw scores: scaled dot product between queries and keys.
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Softmax is evaluated in float32 and the result is cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

    # Dropout on the probabilities drops whole tokens to attend to, as in the
    # original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Multiplicative (head) mask, when provided.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit
class VivitSelfAttention(nn.Module):
def __init__(self, config: VivitConfig) -> None:
@@ -176,16 +206,18 @@ class VivitSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -194,82 +226,37 @@ class VivitSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
class VivitSdpaSelfAttention(VivitSelfAttention):
def __init__(self, config: VivitConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
' removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
head_mask,
output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
@@ -331,13 +318,6 @@ class VivitAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
class VivitSdpaAttention(VivitAttention):
    """`VivitAttention` variant whose `attention` submodule is a `VivitSdpaSelfAttention`."""

    def __init__(self, config: VivitConfig) -> None:
        # Let the parent build its submodules first, then replace the
        # self-attention it created with the SDPA-specific class.
        super().__init__(config)
        self.attention = VivitSdpaSelfAttention(config)
class VivitIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
@@ -372,12 +352,6 @@ class VivitOutput(nn.Module):
return hidden_states
VIVIT_ATTENTION_CLASSES = {
"eager": VivitAttention,
"sdpa": VivitSdpaAttention,
}
class VivitLayer(nn.Module):
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
@@ -385,7 +359,7 @@ class VivitLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = VivitAttention(config)
self.intermediate = VivitIntermediate(config)
self.output = VivitOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -495,6 +469,7 @@ class VivitPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = []
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""

View File

@@ -15,9 +15,8 @@
"""PyTorch YOLOS model."""
import collections.abc
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -25,7 +24,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -231,6 +230,37 @@ class YolosPatchEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Compute attention with explicit matmul / softmax / dropout operations.

    Args:
        module: Module whose `training` flag controls whether dropout is applied.
        query: Query tensor; multiplied with `key` transposed over its last two dimensions.
        key: Key tensor.
        value: Value tensor, weighted by the attention probabilities.
        attention_mask: Optional tensor multiplied element-wise into the attention
            probabilities after dropout. Skipped when `None`.
        scaling: Factor applied to the raw scores before the softmax.
        dropout: Dropout probability for the attention probabilities.
        **kwargs: Accepted so callers can pass extra options (e.g. `is_causal`); unused here.

    Returns:
        A tuple `(attn_output, attn_weights)`: the weighted values with dimensions 1 and 2
        swapped and made contiguous, and the probabilities after dropout and masking.
    """
    # Raw scores: scaled dot product between queries and keys.
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Softmax is evaluated in float32 and the result is cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

    # Dropout on the probabilities drops whole tokens to attend to, as in the
    # original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Multiplicative (head) mask, when provided.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Yolos
class YolosSelfAttention(nn.Module):
def __init__(self, config: YolosConfig) -> None:
@@ -241,16 +271,18 @@ class YolosSelfAttention(nn.Module):
f"heads {config.num_attention_heads}."
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.dropout_prob = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
@@ -259,85 +291,37 @@ class YolosSelfAttention(nn.Module):
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos
class YolosSdpaSelfAttention(YolosSelfAttention):
def __init__(self, config: YolosConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`YolosSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
context_layer, attention_probs = attention_interface(
self,
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
is_causal=self.is_causal,
scaling=self.scaling,
dropout=0.0 if not self.training else self.dropout_prob,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
context_layer = context_layer.reshape(new_context_layer_shape)
return context_layer, None
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
@@ -399,13 +383,6 @@ class YolosAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos
class YolosSdpaAttention(YolosAttention):
    """`YolosAttention` variant whose `attention` submodule is a `YolosSdpaSelfAttention`."""

    def __init__(self, config: YolosConfig) -> None:
        # Let the parent build its submodules first, then replace the
        # self-attention it created with the SDPA-specific class.
        super().__init__(config)
        self.attention = YolosSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
class YolosIntermediate(nn.Module):
def __init__(self, config: YolosConfig) -> None:
@@ -439,9 +416,6 @@ class YolosOutput(nn.Module):
return hidden_states
YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS
class YolosLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@@ -450,7 +424,7 @@ class YolosLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = YolosAttention(config)
self.intermediate = YolosIntermediate(config)
self.output = YolosOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -575,6 +549,7 @@ class YolosPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = []
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -1219,7 +1219,8 @@ class ZoeDepthMetricDepthEstimationHead(nn.Module):
return out, None
# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth
# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth
# avoiding sdpa and flash_attn_2 support, it's done in the backend
class ZoeDepthPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained

View File

@@ -255,6 +255,10 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Inductor error for dynamic shape")
def test_sdpa_can_compile_dynamic(self):
pass
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -15,14 +15,24 @@
"""Testing suite for the PyTorch VideoMAE model."""
import copy
import tempfile
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from pytest import mark
from transformers import VideoMAEConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -338,6 +348,59 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
check_hidden_states_output(inputs_dict, config, model_class)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
    """Compare Flash Attention 2 against the default attention implementation in bfloat16.

    For every model class, a freshly initialised model is saved and reloaded twice:
    once with `attn_implementation="flash_attention_2"` and once without that argument.
    The last hidden states of both models must agree within `atol=rtol=4e-2`.
    The Flash Attention 2 model is then run once more in training mode as a smoke test.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    for model_class in self.all_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        inputs_dict = self._prepare_for_class(inputs_dict, model_class)
        # Both reloaded models are loaded in bfloat16, so the pixel inputs are cast to match.
        inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model_fa = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model_fa.to(torch_device)

            model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
            model.to(torch_device)

            outputs = model(**inputs_dict, output_hidden_states=True)
            outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)

            logits = (
                outputs.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs.decoder_hidden_states[-1]
            )
            logits_fa = (
                outputs_fa.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs_fa.decoder_hidden_states[-1]
            )

            self.assertTrue(
                torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2),
                "Flash Attention 2 and reference hidden states differ beyond tolerance",
            )

            # Smoke-test the Flash Attention 2 path in training mode so any configured
            # dropout is applied. The mode switch must target `model_fa`: it is the
            # model that is run below, and the reference `model` is not used again.
            model_fa.train()
            _ = model_fa(**inputs_dict)
@unittest.skip("Not applicable for VideoMAE")
def test_flash_attn_2_inference_equivalence_right_padding(self):
    """Unconditionally skipped override; the body is intentionally empty."""
# We will verify our results on a video of eating spaghetti
# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]

View File

@@ -19,9 +19,18 @@ import tempfile
import unittest
import numpy as np
from pytest import mark
from transformers import ViTMAEConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -269,6 +278,63 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = ViTMAEModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_inference_equivalence(self):
    """Compare Flash Attention 2 against the default attention implementation in bfloat16.

    For every model class, a freshly initialised model is saved and reloaded twice:
    once with `attn_implementation="flash_attention_2"` and once without that argument.
    Both forward passes are run under the same manual seed, and the last hidden states
    must agree within `atol=rtol=4e-2`. The Flash Attention 2 model is then run once
    more in training mode as a smoke test.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    for model_class in self.all_model_classes:
        if not model_class._supports_flash_attn_2:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        inputs_dict = self._prepare_for_class(inputs_dict, model_class)
        # Both reloaded models are loaded in bfloat16, so the pixel inputs are cast to match.
        inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model_fa = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model_fa.to(torch_device)

            model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
            model.to(torch_device)

            # ForPretraining model has random `noise` -> need to set seed
            # to make the test deterministic
            torch.manual_seed(12345)
            outputs = model(**inputs_dict, output_hidden_states=True)
            torch.manual_seed(12345)
            outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)

            logits = (
                outputs.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs.decoder_hidden_states[-1]
            )
            logits_fa = (
                outputs_fa.hidden_states[-1]
                if not model.config.is_encoder_decoder
                else outputs_fa.decoder_hidden_states[-1]
            )

            self.assertTrue(
                torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2),
                "Flash Attention 2 and reference hidden states differ beyond tolerance",
            )

            # Smoke-test the Flash Attention 2 path in training mode so any configured
            # dropout is applied. The mode switch must target `model_fa`: it is the
            # model that is run below, and the reference `model` is not used again.
            model_fa.train()
            _ = model_fa(**inputs_dict)
# This hunk belongs to the ViTMAE test class, so the skip reason names ViTMAE
# (it previously said "VideoMAE", carried over from the VideoMAE test file).
@unittest.skip("Not applicable for ViTMAE")
def test_flash_attn_2_inference_equivalence_right_padding(self):
    """Unconditionally skipped override; the body is intentionally empty."""
    pass
# We will verify our results on an image of cute cats
def prepare_img():

View File

@@ -130,7 +130,7 @@ class ConfigTester:
general_config_dict = config.to_dict()
# Iterate over all sub_configs if there are any and load them with their own classes
sub_configs = self.config_class.sub_configs
sub_configs = general_config_loaded.sub_configs
for sub_config_key, sub_class in sub_configs.items():
if sub_class.__name__ == "AutoConfig":
sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__

View File

@@ -315,8 +315,6 @@ class ModelTesterMixin:
return inputs_dict
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_save_load(out1, out2):
# make sure we don't have nans
out_2 = out2.cpu().numpy()
@@ -330,6 +328,7 @@ class ModelTesterMixin:
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
model.to(torch_device)
model.eval()
@@ -508,6 +507,7 @@ class ModelTesterMixin:
@is_flaky(description="low likelihood of failure, reason not yet discovered")
def test_save_load_fast_init_from_base(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
@@ -517,7 +517,6 @@ class ModelTesterMixin:
if isinstance(base_class, tuple):
base_class = base_class[0]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
@@ -2228,9 +2227,9 @@ class ModelTesterMixin:
def test_correct_missing_keys(self):
if not self.test_missing_keys:
self.skipTest(reason="test_missing_keys is set to `False`")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
base_model_prefix = model.base_model_prefix
@@ -2287,8 +2286,8 @@ class ModelTesterMixin:
@require_safetensors
def test_can_use_safetensors(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
try:
@@ -2323,9 +2322,9 @@ class ModelTesterMixin:
)
def test_load_save_without_tied_weights(self):
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = False
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as d:
model.save_pretrained(d)
@@ -2373,8 +2372,8 @@ class ModelTesterMixin:
)
def test_model_weights_reload_no_missing_tied_weights(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)