Visual Attention Network (VAN) (#16027)
* encoder works
* added files
* norm in stage
* conversion script
* tests
* fix copies
* make fix-copies
* fixed __init__
* make fix-copies
* fix
* shapiro test needed
* make fix-copies
* minor changes
* make style + quality
* minor refactor conversion script
* rebase + tests
* removed unused variables
* updated doc
* toctree
* CI
* doc
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* resolved conversations
* make fixup
* config passed to modules
* config passed to modules
* Apply suggestions from code review
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* conversations
* conversations
* copyrights
* normal test
* tests
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
committed by GitHub
parent 8f3ea7a1e1
commit 0a057201a9
@@ -319,6 +319,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
1. **[VAN](https://huggingface.co/docs/transformers/master/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMAE](https://huggingface.co/docs/transformers/master/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
@@ -297,6 +297,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
1. **[VAN](https://huggingface.co/docs/transformers/master/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[ViLT)](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
@@ -321,6 +321,7 @@ conda install -c huggingface transformers
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (来自 Microsoft) 伴随论文 [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 由 Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei 发布。
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (来自 Microsoft Research) 伴随论文 [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) 由 Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang 发布。
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (来自 Microsoft Research) 伴随论文 [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) 由 Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu 发布。
1. **[VAN](https://huggingface.co/docs/transformers/master/model_doc/van)** (来自 Tsinghua University and Nankai University) 伴随论文 [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) 由 Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu 发布。
1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (来自 NAVER AI Lab/Kakao Enterprise/Kakao Brain) 伴随论文 [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 由 Wonjae Kim, Bokyung Son, Ildoo Kim 发布。
1. **[ViLT)](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (来自 NAVER AI Lab/Kakao Enterprise/Kakao Brain) 伴随论文 [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 由 Wonjae Kim, Bokyung Son, Ildoo Kim 发布。
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (来自 Google AI) 伴随论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 由 Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 发布。
@@ -333,6 +333,7 @@ conda install -c huggingface transformers
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft) released with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
1. **[VAN](https://huggingface.co/docs/transformers/master/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[ViLT)](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
@@ -334,6 +334,8 @@
      title: UniSpeech
    - local: model_doc/unispeech-sat
      title: UniSpeech-SAT
    - local: model_doc/van
      title: VAN
    - local: model_doc/vilt
      title: ViLT
    - local: model_doc/vision-encoder-decoder
@@ -142,6 +142,7 @@ conversion utilities for the following models.
1. **[TrOCR](model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[UniSpeech](model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
1. **[VAN](model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
1. **[ViLT](model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMAE](model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
@@ -250,6 +251,7 @@ Flax), PyTorch, and/or TensorFlow.
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
| VAN | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViLT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
51
docs/source/model_doc/van.mdx
Normal file
@@ -0,0 +1,51 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# VAN

## Overview

The VAN model was proposed in [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.

The paper introduces a convolution-based attention layer that captures both local and long-range relationships. It does so by combining a standard convolution with a large-kernel convolution; the latter is dilated so that it can capture distant correlations.
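
To make the "local plus distant" idea concrete, the sketch below computes how far a stack of stride-1 convolutions can see along one axis. The kernel sizes used here (5, then 7 with dilation 3) are an assumption about how the large kernel attention is typically configured, not values stated on this page, and `receptive_field` is a hypothetical helper written for illustration.

```python
from typing import List, Tuple


def receptive_field(layers: List[Tuple[int, int]]) -> int:
    """Receptive field (in pixels, along one axis) of stacked stride-1 convolutions.

    Each layer is a (kernel_size, dilation) pair. A stride-1 layer widens the
    receptive field by (kernel_size - 1) * dilation.
    """
    field = 1
    for kernel_size, dilation in layers:
        field += (kernel_size - 1) * dilation
    return field


# Assumed sizes: a 5x5 convolution for local context ...
local_only = receptive_field([(5, 1)])
# ... followed by a 7x7 convolution with dilation 3 for distant context.
local_plus_dilated = receptive_field([(5, 1), (7, 3)])

print(local_only, local_plus_dilated)
```

Under these assumed sizes, the dilated layer reaches much further than the local one while still using only 7 taps per axis, which is the cost argument behind using dilation instead of a dense large kernel.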

The abstract from the paper is the following:

*While originally designed for natural language processing tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks with a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc. Code is available at [this https URL](https://github.com/Visual-Attention-Network/VAN-Classification).*

Tips:

- VAN does not have a separate embedding layer, so `hidden_states` contains one entry per stage.
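
A minimal sketch of why the number of hidden states matches the number of stages: every stage begins with its own strided patch embedding, so the model yields one feature map per stage. The list values below are the `VanConfig` defaults from this commit; the padding of `patch_size // 2` and the helper name `stage_shapes` are assumptions made for illustration.

```python
from typing import List, Tuple


def stage_shapes(
    image_size: int = 224,
    patch_sizes: Tuple[int, ...] = (7, 3, 3, 3),
    strides: Tuple[int, ...] = (4, 2, 2, 2),
    hidden_sizes: Tuple[int, ...] = (64, 128, 320, 512),
) -> List[Tuple[int, int, int]]:
    """Return (channels, height, width) of the feature map after each stage."""
    shapes = []
    size = image_size
    for kernel, stride, channels in zip(patch_sizes, strides, hidden_sizes):
        padding = kernel // 2  # assumption: "same"-style padding in each embedding
        # Standard convolution output-size formula.
        size = (size + 2 * padding - kernel) // stride + 1
        shapes.append((channels, size, size))
    return shapes


shapes = stage_shapes()
for index, shape in enumerate(shapes):
    print(f"stage {index}: {shape}")
```

With the default configuration this gives four feature maps, one for each entry in `hidden_sizes`.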

The figure below illustrates the architecture of a Visual Attention Layer. Taken from the [original paper](https://arxiv.org/abs/2202.09741).

<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/van_architecture.png"/>

This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/Visual-Attention-Network/VAN-Classification).


## VanConfig

[[autodoc]] VanConfig


## VanModel

[[autodoc]] VanModel
    - forward


## VanForImageClassification

[[autodoc]] VanForImageClassification
    - forward
@@ -314,6 +314,7 @@ _import_structure = {
        "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "UniSpeechSatConfig",
    ],
    "models.van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"],
    "models.vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig", "ViltFeatureExtractor", "ViltProcessor"],
    "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
    "models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"],
@@ -1479,6 +1480,14 @@ if is_torch_available():
            "UniSpeechSatPreTrainedModel",
        ]
    )
    _import_structure["models.van"].extend(
        [
            "VAN_PRETRAINED_MODEL_ARCHIVE_LIST",
            "VanForImageClassification",
            "VanModel",
            "VanPreTrainedModel",
        ]
    )
    _import_structure["models.vilt"].extend(
        [
            "VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2612,6 +2621,7 @@ if TYPE_CHECKING:
    from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig, TrOCRProcessor
    from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
    from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
    from .models.van import VAN_PRETRAINED_CONFIG_ARCHIVE_MAP, VanConfig
    from .models.vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig, ViltFeatureExtractor, ViltProcessor
    from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
    from .models.vision_text_dual_encoder import VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor
@@ -3585,6 +3595,12 @@ if TYPE_CHECKING:
            UniSpeechSatModel,
            UniSpeechSatPreTrainedModel,
        )
        from .models.van import (
            VAN_PRETRAINED_MODEL_ARCHIVE_LIST,
            VanForImageClassification,
            VanModel,
            VanPreTrainedModel,
        )
        from .models.vilt import (
            VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
            ViltForImageAndTextRetrieval,
@@ -113,6 +113,7 @@ from . import (
    trocr,
    unispeech,
    unispeech_sat,
    van,
    vilt,
    vision_encoder_decoder,
    vision_text_dual_encoder,
@@ -33,6 +33,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("maskformer", "MaskFormerConfig"),
        ("poolformer", "PoolFormerConfig"),
        ("convnext", "ConvNextConfig"),
        ("van", "VanConfig"),
        ("resnet", "ResNetConfig"),
        ("yoso", "YosoConfig"),
        ("swin", "SwinConfig"),
@@ -134,6 +135,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
        ("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -222,6 +224,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("maskformer", "MaskFormer"),
        ("poolformer", "PoolFormer"),
        ("convnext", "ConvNext"),
        ("van", "VAN"),
        ("resnet", "ResNet"),
        ("yoso", "YOSO"),
        ("swin", "Swin"),
@@ -53,6 +53,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
        ("vit_mae", "ViTFeatureExtractor"),
        ("segformer", "SegformerFeatureExtractor"),
        ("convnext", "ConvNextFeatureExtractor"),
        ("van", "ConvNextFeatureExtractor"),
        ("resnet", "ConvNextFeatureExtractor"),
        ("poolformer", "PoolFormerFeatureExtractor"),
        ("maskformer", "MaskFormerFeatureExtractor"),
@@ -31,6 +31,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("maskformer", "MaskFormerModel"),
        ("poolformer", "PoolFormerModel"),
        ("convnext", "ConvNextModel"),
        ("van", "VanModel"),
        ("resnet", "ResNetModel"),
        ("yoso", "YosoModel"),
        ("swin", "SwinModel"),
@@ -295,6 +296,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ),
        ("swin", "SwinForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("van", "VanForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("poolformer", "PoolFormerForImageClassification"),
    ]
51
src/transformers/models/van/__init__.py
Normal file
@@ -0,0 +1,51 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available, is_vision_available


_import_structure = {
    "configuration_van": ["VAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "VanConfig"],
}


if is_torch_available():
    _import_structure["modeling_van"] = [
        "VAN_PRETRAINED_MODEL_ARCHIVE_LIST",
        "VanForImageClassification",
        "VanModel",
        "VanPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_van import VAN_PRETRAINED_CONFIG_ARCHIVE_MAP, VanConfig

    if is_torch_available():
        from .modeling_van import (
            VAN_PRETRAINED_MODEL_ARCHIVE_LIST,
            VanForImageClassification,
            VanModel,
            VanPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
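
The `_import_structure` dictionary above drives lazy importing: a submodule is only imported when one of its exported names is first accessed. The sketch below illustrates that general idea with a toy `LazyModule` class. It is a simplified stand-in written for illustration, not the library's `_LazyModule`, and it uses standard-library modules in place of `configuration_van` and `modeling_van`.

```python
import importlib
import types
from typing import Dict, List


class LazyModule(types.ModuleType):
    """Toy lazy module: resolves exported names on first attribute access."""

    def __init__(self, name: str, import_structure: Dict[str, List[str]]):
        super().__init__(name)
        # Map each exported name to the module that defines it.
        self._name_to_module = {
            attr: module_name for module_name, attrs in import_structure.items() for attr in attrs
        }
        self.loaded: List[str] = []  # modules imported so far, kept for inspection

    def __getattr__(self, attr: str):
        # Only called when normal lookup fails, i.e. the name is not cached yet.
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module_name = self._name_to_module[attr]
        module = importlib.import_module(module_name)
        self.loaded.append(module_name)
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


# Standard-library modules stand in for the real submodules.
lazy = LazyModule("toy_van", {"json": ["dumps"], "fractions": ["Fraction"]})
print(lazy.loaded)
print(lazy.dumps({"model_type": "van"}))
print(lazy.loaded)
```

Nothing is imported when the toy module is created; the first access to `dumps` triggers the import and caches the result on the module object.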
110
src/transformers/models/van/configuration_van.py
Normal file
@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VAN model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

VAN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "van-base": "https://huggingface.co/Visual-Attention-Network/van-base/blob/main/config.json",
}


class VanConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VanModel`]. It is used to instantiate a VAN model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the VAN [van-base](https://huggingface.co/van-base)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`):
            Patch size to use in each stage's embedding layer.
        strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
            Stride size to use in each stage's embedding layer to downsample the input.
        hidden_sizes (`List[int]`, *optional*, defaults to `[64, 128, 320, 512]`):
            Dimensionality (hidden size) at each stage.
        depths (`List[int]`, *optional*, defaults to `[3, 3, 12, 3]`):
            Depth (number of layers) for each stage.
        mlp_ratios (`List[int]`, *optional*, defaults to `[8, 8, 4, 4]`):
            The expansion ratio for the MLP layer at each stage.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in each layer. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        layer_scale_init_value (`float`, *optional*, defaults to 1e-2):
            The initial value for layer scaling.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for stochastic depth.
        dropout_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for dropout.

    Example:
    ```python
    >>> from transformers import VanModel, VanConfig

    >>> # Initializing a VAN van-base style configuration
    >>> configuration = VanConfig()
    >>> # Initializing a model from the van-base style configuration
    >>> model = VanModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "van"

    def __init__(
        self,
        image_size=224,
        num_channels=3,
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        hidden_sizes=[64, 128, 320, 512],
        depths=[3, 3, 12, 3],
        mlp_ratios=[8, 8, 4, 4],
        hidden_act="gelu",
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        layer_scale_init_value=1e-2,
        drop_path_rate=0.0,
        dropout_rate=0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.mlp_ratios = mlp_ratios
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.layer_scale_init_value = layer_scale_init_value
        self.drop_path_rate = drop_path_rate
        self.dropout_rate = dropout_rate
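
Several `VanConfig` arguments are per-stage lists, so a custom configuration only makes sense if all of them have the same length. The configuration class as shown in this diff stores the values without checking them; the sketch below is a hypothetical helper (`check_stage_lists` is not part of the library) showing one way such a sanity check could look before building a config.

```python
from typing import Dict, List


def check_stage_lists(stage_args: Dict[str, List[int]]) -> int:
    """Return the number of stages, or raise if the per-stage lists disagree in length."""
    lengths = {name: len(values) for name, values in stage_args.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"Per-stage lists must have the same length, got {lengths}")
    return next(iter(lengths.values()))


# Defaults copied from the `VanConfig` signature.
defaults = {
    "patch_sizes": [7, 3, 3, 3],
    "strides": [4, 2, 2, 2],
    "hidden_sizes": [64, 128, 320, 512],
    "depths": [3, 3, 12, 3],
    "mlp_ratios": [8, 8, 4, 4],
}
num_stages = check_stage_lists(defaults)
total_layers = sum(defaults["depths"])
print(num_stages, total_layers)
```

A call such as `check_stage_lists({"depths": [3, 3], "strides": [4, 2, 2]})` raises a `ValueError`, which surfaces a mismatch before any model is instantiated.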
275
src/transformers/models/van/convert_van_to_pytorch.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 BNRist (Tsinghua University), TKLNDST (Nankai University) and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert VAN checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/Visual-Attention-Network/VAN-Classification"""
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
from transformers import AutoFeatureExtractor, VanConfig, VanForImageClassification
|
||||
from transformers.models.van.modeling_van import VanLayerScaling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tracker:
|
||||
module: nn.Module
|
||||
traced: List[nn.Module] = field(default_factory=list)
|
||||
handles: list = field(default_factory=list)
|
||||
|
||||
def _forward_hook(self, m, inputs: Tensor, outputs: Tensor):
|
||||
has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)
|
||||
if has_not_submodules:
|
||||
if not isinstance(m, VanLayerScaling):
|
||||
self.traced.append(m)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
for m in self.module.modules():
|
||||
self.handles.append(m.register_forward_hook(self._forward_hook))
|
||||
self.module(x)
|
||||
list(map(lambda x: x.remove(), self.handles))
|
||||
return self
|
||||
|
||||
@property
|
||||
def parametrized(self):
|
||||
# check the len of the state_dict keys to see if we have learnable params
|
||||
return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleTransfer:
|
||||
src: nn.Module
|
||||
dest: nn.Module
|
||||
verbose: int = 0
|
||||
src_skip: List = field(default_factory=list)
|
||||
dest_skip: List = field(default_factory=list)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
"""
|
||||
Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the
|
||||
hood we tracked all the operations in both modules.
|
||||
"""
|
||||
dest_traced = Tracker(self.dest)(x).parametrized
|
||||
src_traced = Tracker(self.src)(x).parametrized
|
||||
|
||||
src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced))
|
||||
dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced))
|
||||
|
||||
if len(dest_traced) != len(src_traced):
|
||||
raise Exception(
|
||||
f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
|
||||
)
|
||||
|
||||
for dest_m, src_m in zip(dest_traced, src_traced):
|
||||
dest_m.load_state_dict(src_m.state_dict())
|
||||
if self.verbose == 1:
|
||||
print(f"Transferred from={src_m} to={dest_m}")
|
||||
|
||||
|
||||
def copy_parameters(from_model: nn.Module, our_model: nn.Module) -> nn.Module:
|
||||
# nn.Parameter cannot be tracked by the Tracker, thus we need to manually convert them
|
||||
from_state_dict = from_model.state_dict()
|
||||
our_state_dict = our_model.state_dict()
|
||||
config = our_model.config
|
||||
all_keys = []
|
||||
for stage_idx in range(len(config.hidden_sizes)):
|
||||
for block_id in range(config.depths[stage_idx]):
|
||||
from_key = f"block{stage_idx + 1}.{block_id}.layer_scale_1"
|
||||
to_key = f"van.encoder.stages.{stage_idx}.layers.{block_id}.attention_scaling.weight"
|
||||
|
||||
all_keys.append((from_key, to_key))
|
||||
from_key = f"block{stage_idx + 1}.{block_id}.layer_scale_2"
|
||||
to_key = f"van.encoder.stages.{stage_idx}.layers.{block_id}.mlp_scaling.weight"
|
||||
|
||||
all_keys.append((from_key, to_key))
|
||||
|
||||
for from_key, to_key in all_keys:
|
||||
our_state_dict[to_key] = from_state_dict.pop(from_key)
|
||||
|
||||
our_model.load_state_dict(our_state_dict)
|
||||
return our_model
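The key mapping built by `copy_parameters` can be reproduced without loading any model. The sketch below is a standalone illustration that assumes the same key templates as the function above; `layer_scale_key_pairs` is a hypothetical helper name, not part of the conversion script.

```python
from typing import List, Tuple


def layer_scale_key_pairs(depths: List[int]) -> List[Tuple[str, str]]:
    """Hypothetical helper: list (original_key, converted_key) pairs for the layer-scale parameters."""
    pairs = []
    for stage_idx, depth in enumerate(depths):
        for block_id in range(depth):
            # the original checkpoint numbers its stages from 1, the converted model from 0
            src_prefix = f"block{stage_idx + 1}.{block_id}"
            dst_prefix = f"van.encoder.stages.{stage_idx}.layers.{block_id}"
            pairs.append((f"{src_prefix}.layer_scale_1", f"{dst_prefix}.attention_scaling.weight"))
            pairs.append((f"{src_prefix}.layer_scale_2", f"{dst_prefix}.mlp_scaling.weight"))
    return pairs


pairs = layer_scale_key_pairs([3, 3, 5, 2])  # depths of the van-tiny config
print(len(pairs))  # two scaling parameters per block
print(pairs[0])
```

Listing the pairs first makes it easy to check that every layer-scale key of the original checkpoint has a destination before any state dict is modified.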
|
||||
|
||||
|
||||
def convert_weight_and_push(
|
||||
name: str,
|
||||
config: VanConfig,
|
||||
checkpoint: str,
|
||||
from_model: nn.Module,
|
||||
save_directory: Path,
|
||||
push_to_hub: bool = True,
|
||||
):
|
||||
print(f"Downloading weights for {name}...")
|
||||
checkpoint_path = cached_download(checkpoint)
|
||||
print(f"Converting {name}...")
|
||||
from_state_dict = torch.load(checkpoint_path)["state_dict"]
|
||||
from_model.load_state_dict(from_state_dict)
|
||||
from_model.eval()
|
||||
with torch.no_grad():
|
||||
our_model = VanForImageClassification(config).eval()
|
||||
module_transfer = ModuleTransfer(src=from_model, dest=our_model)
|
||||
x = torch.randn((1, 3, 224, 224))
|
||||
module_transfer(x)
|
||||
our_model = copy_parameters(from_model, our_model)
|
||||
|
||||
assert torch.allclose(from_model(x), our_model(x).logits), "The model logits don't match the original one."
|
||||
|
||||
checkpoint_name = name
|
||||
print(checkpoint_name)
|
||||
|
||||
if push_to_hub:
|
||||
our_model.push_to_hub(
|
||||
repo_path_or_name=save_directory / checkpoint_name,
|
||||
commit_message="Add model",
|
||||
use_temp_dir=True,
|
||||
)
|
||||
|
||||
# we can use the convnext one
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
|
||||
feature_extractor.push_to_hub(
|
||||
repo_path_or_name=save_directory / checkpoint_name,
|
||||
commit_message="Add feature extractor",
|
||||
use_temp_dir=True,
|
||||
)
|
||||
|
||||
print(f"Pushed {checkpoint_name}")
|
||||
|
||||
|
||||
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
num_labels = 1000
|
||||
|
||||
repo_id = "datasets/huggingface/label-files"
|
||||
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
ImageNetPreTrainedConfig = partial(VanConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
|
||||
|
||||
names_to_config = {
|
||||
"van-tiny": ImageNetPreTrainedConfig(
|
||||
hidden_sizes=[32, 64, 160, 256],
|
||||
depths=[3, 3, 5, 2],
|
||||
mlp_ratios=[8, 8, 4, 4],
|
||||
),
|
||||
"van-small": ImageNetPreTrainedConfig(
|
||||
hidden_sizes=[64, 128, 320, 512],
|
||||
depths=[2, 2, 4, 2],
|
||||
mlp_ratios=[8, 8, 4, 4],
|
||||
),
|
||||
"van-base": ImageNetPreTrainedConfig(
|
||||
hidden_sizes=[64, 128, 320, 512],
|
||||
depths=[3, 3, 12, 3],
|
||||
mlp_ratios=[8, 8, 4, 4],
|
||||
),
|
||||
"van-large": ImageNetPreTrainedConfig(
|
||||
hidden_sizes=[64, 128, 320, 512],
|
||||
depths=[3, 5, 27, 3],
|
||||
mlp_ratios=[8, 8, 4, 4],
|
||||
),
|
||||
}
|
||||
|
||||
names_to_original_models = {
|
||||
"van-tiny": van_tiny,
|
||||
"van-small": van_small,
|
||||
"van-base": van_base,
|
||||
"van-large": van_large,
|
||||
}
|
||||
|
||||
names_to_original_checkpoints = {
|
||||
"van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny/resolve/main/van_tiny_754.pth.tar",
|
||||
"van-small": "https://huggingface.co/Visual-Attention-Network/VAN-Small/resolve/main/van_small_811.pth.tar",
|
||||
"van-base": "https://huggingface.co/Visual-Attention-Network/VAN-Base/resolve/main/van_base_828.pth.tar",
|
||||
"van-large": "https://huggingface.co/Visual-Attention-Network/VAN-Large/resolve/main/van_large_839.pth.tar",
|
||||
}
|
||||
|
||||
if model_name:
|
||||
convert_weight_and_push(
|
||||
model_name,
|
||||
names_to_config[model_name],
|
||||
checkpoint=names_to_original_checkpoints[model_name],
|
||||
from_model=names_to_original_models[model_name](),
|
||||
save_directory=save_directory,
|
||||
push_to_hub=push_to_hub,
|
||||
)
|
||||
else:
|
||||
for model_name, config in names_to_config.items():
|
||||
convert_weight_and_push(
|
||||
model_name,
|
||||
config,
|
||||
checkpoint=names_to_original_checkpoints[model_name],
|
||||
from_model=names_to_original_models[model_name](),
|
||||
save_directory=save_directory,
|
||||
push_to_hub=push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The name of the model you wish to convert, it must be one of the supported VAN architectures, currently: van-tiny/small/base/large. If `None`, all of them will be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--van_dir",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="A path to VAN's original implementation directory. You can download from here: https://github.com/Visual-Attention-Network/VAN-Classification",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
default=True,
|
||||
type=bool,
|
||||
required=False,
|
||||
help="If True, push model and feature extractor to the hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
|
||||
pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
|
||||
van_dir = args.van_dir
|
||||
# append the parent of the VAN directory to sys.path so that `van` can be imported
|
||||
sys.path.append(str(van_dir.parent))
|
||||
from van.models.van import van_base, van_large, van_small, van_tiny
|
||||
|
||||
convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
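A caveat about the `--push_to_hub` argument declared above: with `type=bool`, argparse calls `bool()` on the raw string, and any non-empty string (including `"False"`) is truthy. The sketch below demonstrates this and shows one possible alternative; the `str_to_bool` helper is hypothetical and not part of the script.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Hypothetical helper: parse common textual spellings of a boolean."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# same declaration style as the script: bool("False") is True
naive = argparse.ArgumentParser()
naive.add_argument("--push_to_hub", default=True, type=bool)
naive_args = naive.parse_args(["--push_to_hub", "False"])

# alternative: an explicit string-to-bool converter
strict = argparse.ArgumentParser()
strict.add_argument("--push_to_hub", default=True, type=str_to_bool)
strict_args = strict.parse_args(["--push_to_hub", "False"])

print(naive_args.push_to_hub, strict_args.push_to_hub)
```

With the declaration used in the script, passing `--push_to_hub False` would therefore still push to the hub.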
|
||||
566
src/transformers/models/van/modeling_van.py
Normal file
@@ -0,0 +1,566 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 BNRist (Tsinghua University), TKLNDST (Nankai University) and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Visual Attention Network (VAN) model."""
|
||||
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_van import VanConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "VanConfig"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
|
||||
|
||||
# Base docstring
|
||||
_CHECKPOINT_FOR_DOC = "van-base"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 512, 7, 7]
|
||||
|
||||
# Image classification docstring
|
||||
_IMAGE_CLASS_CHECKPOINT = "van-base"
|
||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||||
|
||||
VAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"Visual-Attention-Network/van-base",
|
||||
# See all VAN models at https://huggingface.co/models?filter=van
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VanEncoderOutput(ModelOutput):
|
||||
"""
|
||||
Class for [`VanEncoder`]'s outputs, with potential hidden states (feature maps).
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Last hidden states (final feature map) of the last stage of the model.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
|
||||
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VanModelOutput(ModelOutput):
|
||||
"""
|
||||
Class for [`VanModel`]'s outputs, with potential hidden states (feature maps).
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Last hidden states (final feature map) of the last stage of the model.
|
||||
pooler_output (`torch.FloatTensor` of shape `(batch_size, config.hidden_sizes[-1])`):
|
||||
Global average pooling of the last feature map.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
|
||||
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
pooler_output: Optional[torch.FloatTensor] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VanClassifierOutput(ModelOutput):
|
||||
"""
|
||||
Class for [`VanForImageClassification`]'s outputs, with potential hidden states (feature maps).
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
|
||||
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
# Stochastic depth implementation
|
||||
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
||||
"""
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
|
||||
DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
|
||||
Connect' is a different form of dropout in a separate paper... See discussion:
|
||||
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
|
||||
argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
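The binarization step in `drop_path` relies on `floor(keep_prob + u)` being 1 with probability `keep_prob` when `u` is uniform on [0, 1). The plain-Python sketch below applies the same idea to a list of per-sample scalars. It is an illustration only, with a hypothetical function name; the real function operates on tensors.

```python
import math
import random
from typing import List


def drop_path_scalars(values: List[float], drop_prob: float, rng: random.Random) -> List[float]:
    """Illustrative scalar version: each sample is either dropped or rescaled by 1 / keep_prob."""
    if drop_prob == 0.0:
        return list(values)
    keep_prob = 1.0 - drop_prob
    output = []
    for value in values:
        # floor(keep_prob + u) is 1 with probability keep_prob and 0 otherwise, for u in [0, 1)
        mask = math.floor(keep_prob + rng.random())
        output.append(value / keep_prob * mask)
    return output


rng = random.Random(0)
samples = drop_path_scalars([1.0] * 10_000, drop_prob=0.2, rng=rng)
kept_fraction = sum(1 for s in samples if s != 0.0) / len(samples)
mean_value = sum(samples) / len(samples)
print(kept_fraction, mean_value)  # close to 0.8 and 1.0 respectively
```

Dividing the kept samples by `keep_prob` is what keeps the expected value of the output equal to the input.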
|
||||
|
||||
|
||||
# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Van
|
||||
class VanDropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class VanOverlappingPatchEmbedder(nn.Sequential):
|
||||
"""
|
||||
Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by
|
||||
half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
|
||||
Transformer](https://arxiv.org/abs/2106.13797).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stride: int = 4):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2)
|
||||
self.norm = nn.BatchNorm2d(hidden_size)
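The spatial size after each overlapping patch embedding follows the usual convolution arithmetic, `floor((size + 2 * padding - kernel) / stride) + 1`, with `padding = patch_size // 2` as in the class above. The sketch assumes per-stage patch sizes `[7, 3, 3, 3]` and strides `[4, 2, 2, 2]`; these values are an assumption for illustration, since the actual ones come from `VanConfig`, which is not shown here.

```python
def conv_output_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# assumed per-stage settings, for illustration only
patch_sizes = [7, 3, 3, 3]
strides = [4, 2, 2, 2]

size = 224
sizes = []
for patch_size, stride in zip(patch_sizes, strides):
    size = conv_output_size(size, patch_size, stride, padding=patch_size // 2)
    sizes.append(size)

print(sizes)
```

Under these assumptions a 224x224 input ends at 7x7, which agrees with `_EXPECTED_OUTPUT_SHAPE = [1, 512, 7, 7]` in this file.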
|
||||
|
||||
|
||||
class VanMlpLayer(nn.Sequential):
|
||||
"""
|
||||
MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
|
||||
Transformer](https://arxiv.org/abs/2106.13797).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_size: int,
|
||||
out_channels: int,
|
||||
hidden_act: str = "gelu",
|
||||
dropout_rate: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Conv2d(in_channels, hidden_size, kernel_size=1)
|
||||
self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
|
||||
self.activation = ACT2FN[hidden_act]
|
||||
self.drop1 = nn.Dropout(dropout_rate)
|
||||
self.fc2 = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
|
||||
self.drop2 = nn.Dropout(dropout_rate)
|
||||
|
||||
|
||||
class VanLargeKernelAttention(nn.Sequential):
|
||||
"""
|
||||
Basic Large Kernel Attention (LKA).
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int):
|
||||
super().__init__()
|
||||
self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=5, padding=2, groups=hidden_size)
|
||||
self.depth_wise_dilated = nn.Conv2d(
|
||||
hidden_size, hidden_size, kernel_size=7, dilation=3, padding=9, groups=hidden_size
|
||||
)
|
||||
self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
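The padding values in `VanLargeKernelAttention` keep the feature map size unchanged. A kernel of size `k` with dilation `d` spans `d * (k - 1) + 1` input positions, so the size-preserving padding for stride 1 is half of that span minus one. A quick check of the numbers used above (helper names are hypothetical):

```python
def effective_kernel_size(kernel_size: int, dilation: int = 1) -> int:
    """Number of input positions spanned by a (possibly dilated) kernel along one dimension."""
    return dilation * (kernel_size - 1) + 1


def same_padding(kernel_size: int, dilation: int = 1) -> int:
    """Padding that preserves the spatial size for stride 1 and an odd effective kernel."""
    return (effective_kernel_size(kernel_size, dilation) - 1) // 2


print(same_padding(5))              # depth-wise 5x5 convolution
print(effective_kernel_size(7, 3))  # span of the dilated 7x7 convolution
print(same_padding(7, dilation=3))  # its size-preserving padding
```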
|
||||
|
||||
|
||||
class VanLargeKernelAttentionLayer(nn.Module):
|
||||
"""
|
||||
Computes attention using Large Kernel Attention (LKA) and attends the input.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int):
|
||||
super().__init__()
|
||||
self.attention = VanLargeKernelAttention(hidden_size)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
attention = self.attention(hidden_state)
|
||||
attended = hidden_state * attention
|
||||
return attended
|
||||
|
||||
|
||||
class VanSpatialAttentionLayer(nn.Module):
|
||||
"""
|
||||
Van spatial attention layer composed of projection (via conv) -> act -> Large Kernel Attention (LKA) ->
|
||||
projection (via conv) + residual connection.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, hidden_act: str = "gelu"):
|
||||
super().__init__()
|
||||
self.pre_projection = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("conv", nn.Conv2d(hidden_size, hidden_size, kernel_size=1)),
|
||||
("act", ACT2FN[hidden_act]),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.attention_layer = VanLargeKernelAttentionLayer(hidden_size)
|
||||
self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
residual = hidden_state
|
||||
hidden_state = self.pre_projection(hidden_state)
|
||||
hidden_state = self.attention_layer(hidden_state)
|
||||
hidden_state = self.post_projection(hidden_state)
|
||||
hidden_state = hidden_state + residual
|
||||
return hidden_state
|
||||
|
||||
|
||||
class VanLayerScaling(nn.Module):
|
||||
"""
|
||||
Scales the inputs by a learnable parameter initialized to `initial_value`.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, initial_value: float = 1e-2):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
# unsqueezing for broadcasting
|
||||
hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state
|
||||
return hidden_state
|
||||
|
||||
|
||||
class VanLayer(nn.Module):
|
||||
"""
|
||||
Van layer composed of normalization layers, large kernel attention (LKA) and a multi-layer perceptron (MLP).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VanConfig,
|
||||
hidden_size: int,
|
||||
mlp_ratio: int = 4,
|
||||
drop_path_rate: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.drop_path = VanDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
||||
self.pre_norm = nn.BatchNorm2d(hidden_size)
|
||||
self.attention = VanSpatialAttentionLayer(hidden_size, config.hidden_act)
|
||||
self.attention_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
|
||||
self.post_norm = nn.BatchNorm2d(hidden_size)
|
||||
self.mlp = VanMlpLayer(
|
||||
hidden_size, hidden_size * mlp_ratio, hidden_size, config.hidden_act, config.dropout_rate
|
||||
)
|
||||
self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
residual = hidden_state
|
||||
# attention
|
||||
hidden_state = self.pre_norm(hidden_state)
|
||||
hidden_state = self.attention(hidden_state)
|
||||
hidden_state = self.attention_scaling(hidden_state)
|
||||
hidden_state = self.drop_path(hidden_state)
|
||||
# residual connection
|
||||
hidden_state = residual + hidden_state
|
||||
residual = hidden_state
|
||||
# mlp
|
||||
hidden_state = self.post_norm(hidden_state)
|
||||
hidden_state = self.mlp(hidden_state)
|
||||
hidden_state = self.mlp_scaling(hidden_state)
|
||||
hidden_state = self.drop_path(hidden_state)
|
||||
# residual connection
|
||||
hidden_state = residual + hidden_state
|
||||
return hidden_state
|
||||
|
||||
|
||||
class VanStage(nn.Module):
|
||||
"""
|
||||
VanStage, consisting of multiple layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VanConfig,
|
||||
in_channels: int,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
stride: int,
|
||||
depth: int,
|
||||
mlp_ratio: int = 4,
|
||||
drop_path_rate: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.embeddings = VanOverlappingPatchEmbedder(in_channels, hidden_size, patch_size, stride)
|
||||
self.layers = nn.Sequential(
|
||||
*[
|
||||
VanLayer(
|
||||
config,
|
||||
hidden_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop_path_rate=drop_path_rate,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
hidden_state = self.embeddings(hidden_state)
|
||||
hidden_state = self.layers(hidden_state)
|
||||
# rearrange b c h w -> b (h w) c
|
||||
batch_size, hidden_size, height, width = hidden_state.shape
|
||||
hidden_state = hidden_state.flatten(2).transpose(1, 2)
|
||||
hidden_state = self.norm(hidden_state)
|
||||
# rearrange b (h w) c -> b c h w
|
||||
hidden_state = hidden_state.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class VanEncoder(nn.Module):
|
||||
"""
|
||||
VanEncoder, consisting of multiple stages.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VanConfig):
|
||||
super().__init__()
|
||||
self.stages = nn.ModuleList([])
|
||||
patch_sizes = config.patch_sizes
|
||||
strides = config.strides
|
||||
hidden_sizes = config.hidden_sizes
|
||||
depths = config.depths
|
||||
mlp_ratios = config.mlp_ratios
|
||||
drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
||||
|
||||
for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate(
|
||||
zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates)
|
||||
):
|
||||
is_first_stage = num_stage == 0
|
||||
in_channels = hidden_sizes[num_stage - 1]
|
||||
if is_first_stage:
|
||||
in_channels = config.num_channels
|
||||
self.stages.append(
|
||||
VanStage(
|
||||
config,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
patch_size=patch_size,
|
||||
stride=stride,
|
||||
depth=depth,
|
||||
mlp_ratio=mlp_expantion,
|
||||
drop_path_rate=drop_path_rate,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
for _, stage_module in enumerate(self.stages):
|
||||
hidden_state = stage_module(hidden_state)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
|
||||
|
||||
return VanEncoderOutput(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
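A note on the stochastic depth schedule built in the `VanEncoder` constructor: `torch.linspace` produces `sum(config.depths)` evenly spaced rates, but `zip` stops at its shortest input, so as written only the first `len(depths)` rates are consumed, one per stage. The plain-Python sketch below mirrors that behaviour; the `linspace` helper and the 0.1 rate are illustrative assumptions.

```python
from typing import List


def linspace(start: float, end: float, steps: int) -> List[float]:
    """Evenly spaced values including both endpoints (a plain-Python stand-in for torch.linspace)."""
    if steps == 1:
        return [start]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]


depths = [3, 3, 12, 3]  # van-base depths from the conversion script
drop_path_rate = 0.1    # illustrative value
rates = linspace(0.0, drop_path_rate, sum(depths))

# zip truncates to the shortest iterable: one rate per stage, taken from the start of the schedule
per_stage = [rate for _, rate in zip(depths, rates)]
print(len(rates), per_stage)
```

All layers within a stage then share that stage's single rate, because `VanStage` passes the same `drop_path_rate` to every `VanLayer` it creates.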
|
||||
|
||||
|
||||
class VanPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = VanConfig
|
||||
base_model_prefix = "van"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.trunc_normal_(module.weight, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
nn.init.constant_(module.bias, 0)
|
||||
nn.init.constant_(module.weight, 1.0)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
|
||||
fan_out //= module.groups
|
||||
module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, VanModel):
|
||||
module.gradient_checkpointing = value
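For `nn.Conv2d` modules, `_init_weights` draws weights from a normal distribution with standard deviation `sqrt(2 / fan_out)`, where `fan_out` is divided by the number of groups. The sketch below reproduces that arithmetic for two layer shapes used in this file; `conv_init_std` is a hypothetical helper and the hidden size of 64 is only an example.

```python
import math


def conv_init_std(kernel_size: int, out_channels: int, groups: int = 1) -> float:
    """Standard deviation for a square-kernel convolution, following the formula in `_init_weights`."""
    fan_out = kernel_size * kernel_size * out_channels
    fan_out //= groups
    return math.sqrt(2.0 / fan_out)


hidden_size = 64
print(conv_init_std(1, hidden_size))                      # point-wise 1x1 convolution
print(conv_init_std(5, hidden_size, groups=hidden_size))  # depth-wise 5x5 convolution
```

For depth-wise convolutions the division by `groups` cancels the channel count, so the standard deviation depends only on the kernel area.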
|
||||
|
||||
|
||||
VAN_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`VanConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
VAN_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
|
||||
[`AutoFeatureExtractor.__call__`] for details.
|
||||
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all stages. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding layer.",
|
||||
VAN_START_DOCSTRING,
|
||||
)
|
||||
class VanModel(VanPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.encoder = VanEncoder(config)
|
||||
# final layernorm layer
|
||||
self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=VanModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
modality="vision",
|
||||
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||
)
|
||||
def forward(self, pixel_values, output_hidden_states=None, return_dict=None):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
pixel_values,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
# global average pooling, n c w h -> n c
|
||||
pooled_output = last_hidden_state.mean(dim=[-2, -1])
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return VanModelOutput(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
VAN Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
|
||||
ImageNet.
|
||||
""",
|
||||
VAN_START_DOCSTRING,
|
||||
)
|
||||
class VanForImageClassification(VanPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.van = VanModel(config)
|
||||
# Classifier head
|
||||
self.classifier = (
|
||||
nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||
output_type=VanClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||
)
|
||||
def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.van(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
|
||||
|
||||
pooled_output = outputs.pooler_output if return_dict else outputs[1]
|
||||
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.problem_type is None:
|
||||
if self.config.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.config.num_labels == 1:
|
||||
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return VanClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
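The loss selection in `VanForImageClassification.forward` can be summarised as a small decision function. The sketch below is a plain-Python paraphrase under the assumption that the label dtype check is reduced to a boolean flag; `infer_problem_type` is a hypothetical name.

```python
def infer_problem_type(num_labels: int, labels_are_integers: bool) -> str:
    """Mirror of the branch that fills `config.problem_type` when it is unset."""
    if num_labels == 1:
        return "regression"  # paired with MSELoss
    if num_labels > 1 and labels_are_integers:
        return "single_label_classification"  # paired with CrossEntropyLoss
    return "multi_label_classification"  # paired with BCEWithLogitsLoss


print(infer_problem_type(1, labels_are_integers=False))
print(infer_problem_type(1000, labels_are_integers=True))
print(infer_problem_type(1000, labels_are_integers=False))
```

Note that the model stores the inferred value on `self.config`, so the first batch with labels determines the loss used for later calls.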
|
||||
@@ -44,6 +44,7 @@ from .file_utils import (
|
||||
is_pytorch_quantization_available,
|
||||
is_rjieba_available,
|
||||
is_scatter_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_soundfile_availble,
|
||||
is_spacy_available,
|
||||
@@ -351,6 +352,16 @@ def require_sentencepiece(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_scipy(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Scipy. These tests are skipped when Scipy isn't installed.
|
||||
"""
|
||||
if not is_scipy_available():
|
||||
return unittest.skip("test requires Scipy")(test_case)
|
||||
else:
|
||||
return test_case
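The same skip pattern can be written with the standard library alone. The sketch below is a generic variant, not the helper from `testing_utils`: `require_module` is a hypothetical decorator factory that uses `importlib.util.find_spec` for the availability check.

```python
import importlib.util
import unittest


def require_module(module_name: str):
    """Hypothetical decorator factory: skip the test when `module_name` cannot be imported."""

    def decorator(test_case):
        if importlib.util.find_spec(module_name) is None:
            return unittest.skip(f"test requires {module_name}")(test_case)
        return test_case

    return decorator


@require_module("json")
def test_with_stdlib_module():
    return "ran"


@require_module("surely_not_an_installed_module")
def test_with_missing_module():
    return "ran"


print(test_with_stdlib_module())
print(getattr(test_with_missing_module, "__unittest_skip__", False))
```

The availability check runs once, at decoration time, which matches how `require_scipy` evaluates `is_scipy_available()` when the test module is imported.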
|
||||
|
||||
|
||||
def require_tokenizers(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
|
||||
|
||||
@@ -3961,6 +3961,30 @@ class UniSpeechSatPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
VAN_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class VanForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VanModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VanPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
VILT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
0
tests/van/__init__.py
Normal file
273
tests/van/test_modeling_van.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Van model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from transformers import VanConfig
|
||||
from transformers.file_utils import cached_property, is_scipy_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_scipy, require_torch, require_vision, slow, torch_device
|
||||
|
||||
from ..test_configuration_common import ConfigTester
|
||||
from ..test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
from scipy import stats
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import VanForImageClassification, VanModel
|
||||
from transformers.models.van.modeling_van import VAN_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
class VanModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
image_size=224,
|
||||
num_channels=3,
|
||||
hidden_sizes=[16, 32, 64, 128],
|
||||
depths=[1, 1, 1, 1],
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.num_channels = num_channels
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.depths = depths
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.num_labels = num_labels
|
||||
self.type_sequence_label_size = num_labels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return VanConfig(
|
||||
num_channels=self.num_channels,
|
||||
hidden_sizes=self.hidden_sizes,
|
||||
depths=self.depths,
|
||||
num_labels=self.num_labels,
|
||||
is_decoder=False,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = VanModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
# expected last hidden states: B, C, H // 32, W // 32
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape,
|
||||
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
|
||||
)
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
model = VanForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class VanModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as Van does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (VanModel, VanForImageClassification) if is_torch_available() else ()
|
||||
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = VanModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=VanConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.create_and_test_config_common_properties()
|
||||
self.config_tester.create_and_test_config_to_json_string()
|
||||
self.config_tester.create_and_test_config_to_json_file()
|
||||
self.config_tester.create_and_test_config_from_and_save_pretrained()
|
||||
self.config_tester.create_and_test_config_with_num_labels()
|
||||
self.config_tester.check_config_can_be_init_without_params()
|
||||
self.config_tester.check_config_arguments_init()
|
||||
|
||||
def create_and_test_config_common_properties(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Van does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Van does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@require_scipy
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
configs_no_init = _config_zero_init(config)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
|
||||
self.assertTrue(
|
||||
torch.all(module.weight == 1),
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.all(module.bias == 0),
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
|
||||
fan_out //= module.groups
|
||||
std = math.sqrt(2.0 / fan_out)
|
||||
# divide by the expected std -> correctly initialized weights have mean = 0, std = 1
|
||||
data = module.weight.data.cpu().flatten().numpy() / std
|
||||
test = stats.anderson(data)
|
||||
self.assertTrue(test.statistic > 0.05)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_stages = len(self.model_tester.hidden_sizes)
|
||||
# van has no embeddings
|
||||
self.assertEqual(len(hidden_states), expected_num_stages)
|
||||
|
||||
# Van's feature maps are of shape (batch_size, num_channels, height, width)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.image_size // 4, self.model_tester.image_size // 4],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also works using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in VAN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = VanModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class VanModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return AutoFeatureExtractor.from_pretrained(VAN_PRETRAINED_MODEL_ARCHIVE_LIST[0])
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
model = VanForImageClassification.from_pretrained(VAN_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([0.1029, -0.0904, -0.6365]).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
||||
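Review note: `test_initialization` above derives the expected standard deviation of each `nn.Conv2d` weight from its fan-out before normalizing the weights. The arithmetic can be reproduced in isolation as follows. `expected_conv_std` is a hypothetical helper name, not part of the commit, and the sketch covers only the standard-deviation computation, not the normality check.

```python
import math


def expected_conv_std(kernel_size, out_channels, groups=1):
    # Hypothetical helper mirroring the test's arithmetic:
    # fan_out = kernel_height * kernel_width * out_channels // groups
    fan_out = kernel_size[0] * kernel_size[1] * out_channels
    fan_out //= groups
    # Expected std of the weights, as computed in the test.
    return math.sqrt(2.0 / fan_out)


# A 3x3 convolution with 16 output channels: fan_out = 144
dense_std = expected_conv_std((3, 3), 16)

# A depthwise 5x5 convolution with 16 channels (groups=16): fan_out = 25
depthwise_std = expected_conv_std((5, 5), 16, groups=16)
```

Dividing the flattened weights by this value is what lets the test compare them against a standard normal distribution.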
@@ -18,6 +18,7 @@ src/transformers/models/swin/modeling_swin.py
|
||||
src/transformers/models/convnext/modeling_convnext.py
|
||||
src/transformers/models/poolformer/modeling_poolformer.py
|
||||
src/transformers/models/vit_mae/modeling_vit_mae.py
|
||||
src/transformers/models/van/modeling_van.py
|
||||
src/transformers/models/segformer/modeling_segformer.py
|
||||
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
|
||||
src/transformers/models/bart/modeling_bart.py
|
||||
|
||||