Refactor Attention implementation for ViT-based models (#36545)
* Refactor vit attention * Refactor ViT-based models * 🚨🚨🚨 Fix prefix for DPT * Update params order * trigger tests * Fix Dinov2 attention * Fix DPT attention impl propagation for backbone config * Common test fix: config is modif. inplace - avoid it * view->reshape * Fixup * Fixup * Enable IJepa FA2 * Add FA2 in corresponding model docs
This commit is contained in:
committed by
GitHub
parent
730d2a52e7
commit
66291778dd
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ specific language governing permissions and limitations under the License.
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ rendered properly in your Markdown viewer.
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@@ -2098,6 +2098,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if not isinstance(requested_attn_implementation, dict)
|
||||
else requested_attn_implementation.get(key, None)
|
||||
)
|
||||
# For models with backbone sub-config might be not initialized
|
||||
if sub_config is not None:
|
||||
sub_config._attn_implementation_internal = curr_attn_implementation
|
||||
|
||||
if use_flash_attention_2:
|
||||
|
||||
@@ -14,8 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Audio Spectrogram Transformer (AST) model."""
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -24,7 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
from .configuration_audio_spectrogram_transformer import ASTConfig
|
||||
@@ -108,6 +107,37 @@ class ASTPatchEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
|
||||
class ASTSelfAttention(nn.Module):
|
||||
def __init__(self, config: ASTConfig) -> None:
|
||||
@@ -118,16 +148,18 @@ class ASTSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -136,85 +168,37 @@ class ASTSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
|
||||
class ASTSdpaSelfAttention(ASTSelfAttention):
|
||||
def __init__(self, config: ASTConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`ASTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
|
||||
@@ -276,13 +260,6 @@ class ASTAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
|
||||
class ASTSdpaAttention(ASTAttention):
|
||||
def __init__(self, config: ASTConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = ASTSdpaSelfAttention(config)
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
|
||||
class ASTIntermediate(nn.Module):
|
||||
def __init__(self, config: ASTConfig) -> None:
|
||||
@@ -316,12 +293,6 @@ class ASTOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
AST_ATTENTION_CLASSES = {
|
||||
"eager": ASTAttention,
|
||||
"sdpa": ASTSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
|
||||
class ASTLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
@@ -330,7 +301,7 @@ class ASTLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = ASTAttention(config)
|
||||
self.intermediate = ASTIntermediate(config)
|
||||
self.output = ASTOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -428,6 +399,7 @@ class ASTPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "input_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
# Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
|
||||
@@ -15,9 +15,8 @@
|
||||
"""PyTorch DeiT model."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
from typing import Callable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -31,7 +30,7 @@ from ...modeling_outputs import (
|
||||
ImageClassifierOutput,
|
||||
MaskedImageModelingOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
@@ -180,6 +179,37 @@ class DeiTPatchEmbeddings(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
|
||||
class DeiTSelfAttention(nn.Module):
|
||||
def __init__(self, config: DeiTConfig) -> None:
|
||||
@@ -190,16 +220,18 @@ class DeiTSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -208,85 +240,37 @@ class DeiTSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT
|
||||
class DeiTSdpaSelfAttention(DeiTSelfAttention):
|
||||
def __init__(self, config: DeiTConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`DeiTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
|
||||
@@ -348,13 +332,6 @@ class DeiTAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT
|
||||
class DeiTSdpaAttention(DeiTAttention):
|
||||
def __init__(self, config: DeiTConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = DeiTSdpaSelfAttention(config)
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
|
||||
class DeiTIntermediate(nn.Module):
|
||||
def __init__(self, config: DeiTConfig) -> None:
|
||||
@@ -388,12 +365,6 @@ class DeiTOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
DEIT_ATTENTION_CLASSES = {
|
||||
"eager": DeiTAttention,
|
||||
"sdpa": DeiTSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
|
||||
class DeiTLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
@@ -402,7 +373,7 @@ class DeiTLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = DeiTAttention(config)
|
||||
self.intermediate = DeiTIntermediate(config)
|
||||
self.output = DeiTOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -501,6 +472,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DeiTLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -240,7 +240,8 @@ class DepthAnythingFeatureFusionStage(nn.Module):
|
||||
return fused_hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything
|
||||
# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->DepthAnything,dpt->depth_anything
|
||||
# avoiding sdpa and flash_attn_2 support, it's done in the backend
|
||||
class DepthAnythingPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
"""PyTorch DINOv2 model."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -30,7 +29,7 @@ from ...modeling_outputs import (
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
@@ -172,6 +171,37 @@ class Dinov2PatchEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
|
||||
class Dinov2SelfAttention(nn.Module):
|
||||
def __init__(self, config: Dinov2Config) -> None:
|
||||
@@ -182,16 +212,18 @@ class Dinov2SelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -200,78 +232,37 @@ class Dinov2SelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
|
||||
def __init__(self, config: Dinov2Config) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"Dinov2Model is using Dinov2SdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
|
||||
@@ -333,13 +324,6 @@ class Dinov2Attention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2
|
||||
class Dinov2SdpaAttention(Dinov2Attention):
|
||||
def __init__(self, config: Dinov2Config) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = Dinov2SdpaSelfAttention(config)
|
||||
|
||||
|
||||
class Dinov2LayerScale(nn.Module):
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
@@ -421,12 +405,6 @@ class Dinov2SwiGLUFFN(nn.Module):
|
||||
return self.weights_out(hidden)
|
||||
|
||||
|
||||
DINOV2_ATTENTION_CLASSES = {
|
||||
"eager": Dinov2Attention,
|
||||
"sdpa": Dinov2SdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Dinov2Layer(nn.Module):
|
||||
"""This corresponds to the Block class in the original implementation."""
|
||||
|
||||
@@ -434,7 +412,7 @@ class Dinov2Layer(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = Dinov2Attention(config)
|
||||
self.layer_scale1 = Dinov2LayerScale(config)
|
||||
self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
||||
|
||||
@@ -542,6 +520,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Dinov2SwiGLUFFN"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -21,8 +21,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -30,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
@@ -185,6 +184,36 @@ class Dinov2WithRegistersEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Dinov2WithRegistersSelfAttention(nn.Module):
|
||||
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
|
||||
super().__init__()
|
||||
@@ -194,16 +223,18 @@ class Dinov2WithRegistersSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -212,78 +243,37 @@ class Dinov2WithRegistersSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention):
|
||||
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Dinov2WithRegistersSelfOutput(nn.Module):
|
||||
@@ -343,12 +333,6 @@ class Dinov2WithRegistersAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention):
|
||||
def __init__(self, config: Dinov2WithRegistersConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = Dinov2WithRegistersSdpaSelfAttention(config)
|
||||
|
||||
|
||||
class Dinov2WithRegistersLayerScale(nn.Module):
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
@@ -428,12 +412,6 @@ class Dinov2WithRegistersSwiGLUFFN(nn.Module):
|
||||
return self.weights_out(hidden)
|
||||
|
||||
|
||||
DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = {
|
||||
"eager": Dinov2WithRegistersAttention,
|
||||
"sdpa": Dinov2WithRegistersSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Dinov2WithRegistersLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the original implementation."""
|
||||
|
||||
@@ -441,7 +419,7 @@ class Dinov2WithRegistersLayer(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = Dinov2WithRegistersAttention(config)
|
||||
self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
|
||||
self.drop_path = (
|
||||
Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
||||
@@ -550,6 +528,7 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -282,5 +282,9 @@ class DPTConfig(PretrainedConfig):
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
@property
|
||||
def sub_configs(self):
|
||||
return {"backbone_config": type(self.backbone_config)} if self.backbone_config is not None else {}
|
||||
|
||||
|
||||
__all__ = ["DPTConfig"]
|
||||
|
||||
@@ -20,9 +20,8 @@ https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_hea
|
||||
"""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -37,7 +36,7 @@ from ...file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, logging, torch_int
|
||||
from ...utils.backbone_utils import load_backbone
|
||||
@@ -295,8 +294,39 @@ class DPTViTPatchEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT
|
||||
class DPTViTSelfAttention(nn.Module):
|
||||
class DPTSelfAttention(nn.Module):
|
||||
def __init__(self, config: DPTConfig) -> None:
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
@@ -305,16 +335,18 @@ class DPTViTSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -323,33 +355,33 @@ class DPTViTSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
@@ -378,7 +410,7 @@ class DPTViTSelfOutput(nn.Module):
|
||||
class DPTViTAttention(nn.Module):
|
||||
def __init__(self, config: DPTConfig) -> None:
|
||||
super().__init__()
|
||||
self.attention = DPTViTSelfAttention(config)
|
||||
self.attention = DPTSelfAttention(config)
|
||||
self.output = DPTViTSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@@ -809,6 +841,8 @@ class DPTPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "dpt"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -5,8 +5,7 @@
|
||||
# modular_ijepa.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -14,7 +13,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
@@ -167,6 +166,7 @@ class IJepaPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
@@ -189,6 +189,36 @@ class IJepaPreTrainedModel(PreTrainedModel):
|
||||
).to(module.position_embeddings.dtype)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class IJepaSelfAttention(nn.Module):
|
||||
def __init__(self, config: IJepaConfig) -> None:
|
||||
super().__init__()
|
||||
@@ -198,16 +228,18 @@ class IJepaSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -216,84 +248,37 @@ class IJepaSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class IJepaSdpaSelfAttention(IJepaSelfAttention):
|
||||
def __init__(self, config: IJepaConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`IJepaSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class IJepaSelfOutput(nn.Module):
|
||||
@@ -353,12 +338,6 @@ class IJepaAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class IJepaSdpaAttention(IJepaAttention):
|
||||
def __init__(self, config: IJepaConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = IJepaSdpaSelfAttention(config)
|
||||
|
||||
|
||||
class IJepaIntermediate(nn.Module):
|
||||
def __init__(self, config: IJepaConfig) -> None:
|
||||
super().__init__()
|
||||
@@ -390,12 +369,6 @@ class IJepaOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
IJEPA_ATTENTION_CLASSES = {
|
||||
"eager": IJepaAttention,
|
||||
"sdpa": IJepaSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class IJepaLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
@@ -403,7 +376,7 @@ class IJepaLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = IJepaAttention(config)
|
||||
self.intermediate = IJepaIntermediate(config)
|
||||
self.output = IJepaOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -531,7 +504,6 @@ IJEPA_INPUTS_DOCSTRING = r"""
|
||||
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]
|
||||
|
||||
|
||||
IJEPA_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
|
||||
@@ -108,6 +108,7 @@ class IJepaPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -15,10 +15,9 @@
|
||||
"""PyTorch VideoMAE (masked autoencoder) model."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
from typing import Callable, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -28,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
@@ -196,6 +195,37 @@ class VideoMAEPatchEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class VideoMAESelfAttention(nn.Module):
|
||||
def __init__(self, config: VideoMAEConfig) -> None:
|
||||
super().__init__()
|
||||
@@ -204,10 +234,13 @@ class VideoMAESelfAttention(nn.Module):
|
||||
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
|
||||
@@ -220,8 +253,6 @@ class VideoMAESelfAttention(nn.Module):
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -239,65 +270,33 @@ class VideoMAESelfAttention(nn.Module):
|
||||
value_layer = self.transpose_for_scores(values)
|
||||
query_layer = self.transpose_for_scores(queries)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class VideoMAESdpaSelfAttention(VideoMAESelfAttention):
|
||||
def __init__(self, config: VideoMAEConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
|
||||
keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
|
||||
values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
|
||||
queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
|
||||
|
||||
key_layer = self.transpose_for_scores(keys)
|
||||
value_layer = self.transpose_for_scores(values)
|
||||
query_layer = self.transpose_for_scores(queries)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
|
||||
@@ -359,13 +358,6 @@ class VideoMAEAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE
|
||||
class VideoMAESdpaAttention(VideoMAEAttention):
|
||||
def __init__(self, config: VideoMAEConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = VideoMAESdpaSelfAttention(config)
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
|
||||
class VideoMAEIntermediate(nn.Module):
|
||||
def __init__(self, config: VideoMAEConfig) -> None:
|
||||
@@ -399,9 +391,6 @@ class VideoMAEOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
|
||||
class VideoMAELayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
@@ -410,7 +399,7 @@ class VideoMAELayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = VideoMAEAttention(config)
|
||||
self.intermediate = VideoMAEIntermediate(config)
|
||||
self.output = VideoMAEOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -508,6 +497,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -30,7 +30,7 @@ from ...modeling_outputs import (
|
||||
ImageClassifierOutput,
|
||||
MaskedImageModelingOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
@@ -184,6 +184,36 @@ class ViTPatchEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class ViTSelfAttention(nn.Module):
|
||||
def __init__(self, config: ViTConfig) -> None:
|
||||
super().__init__()
|
||||
@@ -193,16 +223,18 @@ class ViTSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -211,84 +243,37 @@ class ViTSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ViTSdpaSelfAttention(ViTSelfAttention):
|
||||
def __init__(self, config: ViTConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`ViTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ViTSelfOutput(nn.Module):
|
||||
@@ -348,12 +333,6 @@ class ViTAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class ViTSdpaAttention(ViTAttention):
|
||||
def __init__(self, config: ViTConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = ViTSdpaSelfAttention(config)
|
||||
|
||||
|
||||
class ViTIntermediate(nn.Module):
|
||||
def __init__(self, config: ViTConfig) -> None:
|
||||
super().__init__()
|
||||
@@ -385,12 +364,6 @@ class ViTOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
VIT_ATTENTION_CLASSES = {
|
||||
"eager": ViTAttention,
|
||||
"sdpa": ViTSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class ViTLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
@@ -398,7 +371,7 @@ class ViTLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = ViTAttention(config)
|
||||
self.intermediate = ViTIntermediate(config)
|
||||
self.output = ViTOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -496,6 +469,7 @@ class ViTPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -15,10 +15,9 @@
|
||||
"""PyTorch ViT MAE (masked autoencoder) model."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
from typing import Callable, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -27,7 +26,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
@@ -356,6 +355,37 @@ class ViTMAEPatchEmbeddings(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
|
||||
class ViTMAESelfAttention(nn.Module):
|
||||
def __init__(self, config: ViTMAEConfig) -> None:
|
||||
@@ -366,16 +396,18 @@ class ViTMAESelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -384,85 +416,37 @@ class ViTMAESelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE
|
||||
class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
|
||||
def __init__(self, config: ViTMAEConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`ViTMAESdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
|
||||
@@ -524,13 +508,6 @@ class ViTMAEAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE
|
||||
class ViTMAESdpaAttention(ViTMAEAttention):
|
||||
def __init__(self, config: ViTMAEConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = ViTMAESdpaSelfAttention(config)
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
|
||||
class ViTMAEIntermediate(nn.Module):
|
||||
def __init__(self, config: ViTMAEConfig) -> None:
|
||||
@@ -564,12 +541,6 @@ class ViTMAEOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
VITMAE_ATTENTION_CLASSES = {
|
||||
"eager": ViTMAEAttention,
|
||||
"sdpa": ViTMAESdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
|
||||
class ViTMAELayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
@@ -578,7 +549,7 @@ class ViTMAELayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = ViTMAEAttention(config)
|
||||
self.intermediate = ViTMAEIntermediate(config)
|
||||
self.output = ViTMAEOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -676,6 +647,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
"""PyTorch ViT MSN (masked siamese network) model."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -25,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@@ -173,6 +172,37 @@ class ViTMSNPatchEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN
|
||||
class ViTMSNSelfAttention(nn.Module):
|
||||
def __init__(self, config: ViTMSNConfig) -> None:
|
||||
@@ -183,16 +213,18 @@ class ViTMSNSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -201,85 +233,37 @@ class ViTMSNSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN
|
||||
class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention):
|
||||
def __init__(self, config: ViTMSNConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`ViTMSNSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
|
||||
@@ -341,13 +325,6 @@ class ViTMSNAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN
|
||||
class ViTMSNSdpaAttention(ViTMSNAttention):
|
||||
def __init__(self, config: ViTMSNConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = ViTMSNSdpaSelfAttention(config)
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
|
||||
class ViTMSNIntermediate(nn.Module):
|
||||
def __init__(self, config: ViTMSNConfig) -> None:
|
||||
@@ -381,9 +358,6 @@ class ViTMSNOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN
|
||||
class ViTMSNLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
@@ -392,7 +366,7 @@ class ViTMSNLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = ViTMSNAttention(config)
|
||||
self.intermediate = ViTMSNIntermediate(config)
|
||||
self.output = ViTMSNOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -491,6 +465,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
# todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
|
||||
# when creating pre-training scripts.
|
||||
|
||||
@@ -20,8 +20,7 @@ This code is the same as the original Vision Transformer (ViT) with 2 modificati
|
||||
"""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
from typing import Callable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -29,7 +28,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BackboneOutput, BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@@ -103,6 +102,37 @@ class VitPoseBackboneEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
|
||||
class VitPoseBackboneSelfAttention(nn.Module):
|
||||
def __init__(self, config: VitPoseBackboneConfig) -> None:
|
||||
@@ -113,16 +143,18 @@ class VitPoseBackboneSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -131,33 +163,33 @@ class VitPoseBackboneSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
@@ -392,6 +424,8 @@ class VitPoseBackbonePreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -14,8 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch ViViT model."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Set, Tuple, Union
|
||||
from typing import Callable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -24,7 +23,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@@ -166,6 +165,37 @@ class VivitEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit
|
||||
class VivitSelfAttention(nn.Module):
|
||||
def __init__(self, config: VivitConfig) -> None:
|
||||
@@ -176,16 +206,18 @@ class VivitSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -194,82 +226,37 @@ class VivitSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
|
||||
class VivitSdpaSelfAttention(VivitSelfAttention):
|
||||
def __init__(self, config: VivitConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
|
||||
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
|
||||
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
|
||||
' removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
head_mask,
|
||||
output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
|
||||
@@ -331,13 +318,6 @@ class VivitAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
|
||||
class VivitSdpaAttention(VivitAttention):
|
||||
def __init__(self, config: VivitConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = VivitSdpaSelfAttention(config)
|
||||
|
||||
|
||||
class VivitIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -372,12 +352,6 @@ class VivitOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
VIVIT_ATTENTION_CLASSES = {
|
||||
"eager": VivitAttention,
|
||||
"sdpa": VivitSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class VivitLayer(nn.Module):
|
||||
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
|
||||
|
||||
@@ -385,7 +359,7 @@ class VivitLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = VivitAttention(config)
|
||||
self.intermediate = VivitIntermediate(config)
|
||||
self.output = VivitOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -495,6 +469,7 @@ class VivitPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = []
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -15,9 +15,8 @@
|
||||
"""PyTorch YOLOS model."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -25,7 +24,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
@@ -231,6 +230,37 @@ class YolosPatchEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights * attention_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Yolos
|
||||
class YolosSelfAttention(nn.Module):
|
||||
def __init__(self, config: YolosConfig) -> None:
|
||||
@@ -241,16 +271,18 @@ class YolosSelfAttention(nn.Module):
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
@@ -259,85 +291,37 @@ class YolosSelfAttention(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos
|
||||
class YolosSdpaSelfAttention(YolosSelfAttention):
|
||||
def __init__(self, config: YolosConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
if output_attentions or head_mask is not None:
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`YolosSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
|
||||
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
context_layer = torch.nn.functional.scaled_dot_product_attention(
|
||||
context_layer, attention_probs = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
head_mask,
|
||||
self.attention_probs_dropout_prob if self.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
is_causal=self.is_causal,
|
||||
scaling=self.scaling,
|
||||
dropout=0.0 if not self.training else self.dropout_prob,
|
||||
)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
return context_layer, None
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
|
||||
@@ -399,13 +383,6 @@ class YolosAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos
|
||||
class YolosSdpaAttention(YolosAttention):
|
||||
def __init__(self, config: YolosConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.attention = YolosSdpaSelfAttention(config)
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
|
||||
class YolosIntermediate(nn.Module):
|
||||
def __init__(self, config: YolosConfig) -> None:
|
||||
@@ -439,9 +416,6 @@ class YolosOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS
|
||||
class YolosLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
@@ -450,7 +424,7 @@ class YolosLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.attention = YolosAttention(config)
|
||||
self.intermediate = YolosIntermediate(config)
|
||||
self.output = YolosOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -575,6 +549,7 @@ class YolosPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = []
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -1219,7 +1219,8 @@ class ZoeDepthMetricDepthEstimationHead(nn.Module):
|
||||
return out, None
|
||||
|
||||
|
||||
# Copied from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth
|
||||
# Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth
|
||||
# avoiding sdpa and flash_attn_2 support, it's done int the backend
|
||||
class ZoeDepthPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
|
||||
@@ -255,6 +255,10 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inductor error for dynamic shape")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -15,14 +15,24 @@
|
||||
"""Testing suite for the PyTorch VideoMAE model."""
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pytest import mark
|
||||
|
||||
from transformers import VideoMAEConfig
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -338,6 +348,59 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
outputs = model(**inputs_dict, output_hidden_states=True)
|
||||
outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)
|
||||
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(**inputs_dict)
|
||||
|
||||
@unittest.skip("Not applicable for VideoMAE")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on a video of eating spaghetti
|
||||
# Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
|
||||
|
||||
@@ -19,9 +19,18 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from pytest import mark
|
||||
|
||||
from transformers import ViTMAEConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -269,6 +278,63 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
model = ViTMAEModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs_dict["pixel_values"] = inputs_dict["pixel_values"].to(torch.bfloat16)
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
# ForPretraining model has random `noise` -> need to set seed
|
||||
# to make the test deterministic
|
||||
torch.manual_seed(12345)
|
||||
outputs = model(**inputs_dict, output_hidden_states=True)
|
||||
torch.manual_seed(12345)
|
||||
outputs_fa = model_fa(**inputs_dict, output_hidden_states=True)
|
||||
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(**inputs_dict)
|
||||
|
||||
@unittest.skip("Not applicable for VideoMAE")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
||||
@@ -130,7 +130,7 @@ class ConfigTester:
|
||||
general_config_dict = config.to_dict()
|
||||
|
||||
# Iterate over all sub_configs if there are any and load them with their own classes
|
||||
sub_configs = self.config_class.sub_configs
|
||||
sub_configs = general_config_loaded.sub_configs
|
||||
for sub_config_key, sub_class in sub_configs.items():
|
||||
if sub_class.__name__ == "AutoConfig":
|
||||
sub_class = sub_class.for_model(**general_config_dict[sub_config_key]).__class__
|
||||
|
||||
@@ -315,8 +315,6 @@ class ModelTesterMixin:
|
||||
return inputs_dict
|
||||
|
||||
def test_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_save_load(out1, out2):
|
||||
# make sure we don't have nans
|
||||
out_2 = out2.cpu().numpy()
|
||||
@@ -330,6 +328,7 @@ class ModelTesterMixin:
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -508,6 +507,7 @@ class ModelTesterMixin:
|
||||
|
||||
@is_flaky(description="low likelihood of failure, reason not yet discovered")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
|
||||
@@ -517,7 +517,6 @@ class ModelTesterMixin:
|
||||
if isinstance(base_class, tuple):
|
||||
base_class = base_class[0]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
@@ -2228,9 +2227,9 @@ class ModelTesterMixin:
|
||||
def test_correct_missing_keys(self):
|
||||
if not self.test_missing_keys:
|
||||
self.skipTest(reason="test_missing_keys is set to `False`")
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
base_model_prefix = model.base_model_prefix
|
||||
|
||||
@@ -2287,8 +2286,8 @@ class ModelTesterMixin:
|
||||
|
||||
@require_safetensors
|
||||
def test_can_use_safetensors(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model_tied = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
try:
|
||||
@@ -2323,9 +2322,9 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def test_load_save_without_tied_weights(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.tie_word_embeddings = False
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
model.save_pretrained(d)
|
||||
@@ -2373,8 +2372,8 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def test_model_weights_reload_no_missing_tied_weights(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
Reference in New Issue
Block a user