add ONNX support for swin transformer (#19390)

* swin transformer onnx support

* Updated image dimensions as dynamic

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Bibhabasu Mohapatra
2022-10-07 18:53:24 +05:30
committed by GitHub
parent 969534af4b
commit e162cebfa3
5 changed files with 30 additions and 2 deletions

View File

@@ -94,6 +94,7 @@ Ready-made configurations include the following architectures:
- RoFormer
- SegFormer
- SqueezeBERT
- Swin Transformer
- T5
- ViT
- XLM

View File

@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"]}
_import_structure = {"configuration_swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig", "SwinOnnxConfig"]}
try:
@@ -53,7 +53,7 @@ else:
]
if TYPE_CHECKING:
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig
from .configuration_swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig, SwinOnnxConfig
try:
if not is_torch_available():

View File

@@ -14,7 +14,13 @@
# limitations under the License.
""" Swin Transformer model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@@ -145,3 +151,20 @@ class SwinConfig(PretrainedConfig):
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
class SwinOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4

View File

@@ -471,6 +471,9 @@ class FeaturesManager:
"question-answering",
onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
),
"swin": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls="models.swin.SwinOnnxConfig"
),
"t5": supported_features_mapping(
"default",
"default-with-past",

View File

@@ -217,6 +217,7 @@ PYTORCH_EXPORT_MODELS = {
("longformer", "allenai/longformer-base-4096"),
("yolos", "hustvl/yolos-tiny"),
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
("swin", "microsoft/swin-tiny-patch4-window7-224"),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {