Add SegFormer ONNX support (#18006)
* Add ONNX support * Make height and width dynamic axes Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -90,6 +90,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- ResNet
|
- ResNet
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
- RoFormer
|
- RoFormer
|
||||||
|
- SegFormer
|
||||||
- SqueezeBERT
|
- SqueezeBERT
|
||||||
- T5
|
- T5
|
||||||
- ViT
|
- ViT
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ from ...utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]}
|
_import_structure = {
|
||||||
|
"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig", "SegformerOnnxConfig"]
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
@@ -69,7 +71,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
|
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig, SegformerOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
|
|||||||
@@ -15,8 +15,13 @@
|
|||||||
""" SegFormer model configuration"""
|
""" SegFormer model configuration"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -148,3 +153,24 @@ class SegformerConfig(PretrainedConfig):
|
|||||||
self.decoder_hidden_size = decoder_hidden_size
|
self.decoder_hidden_size = decoder_hidden_size
|
||||||
self.reshape_last_stage = kwargs.get("reshape_last_stage", True)
|
self.reshape_last_stage = kwargs.get("reshape_last_stage", True)
|
||||||
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
||||||
|
|
||||||
|
|
||||||
|
class SegformerOnnxConfig(OnnxConfig):
|
||||||
|
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-4
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_onnx_opset(self) -> int:
|
||||||
|
return 12
|
||||||
|
|||||||
@@ -456,6 +456,12 @@ class FeaturesManager:
|
|||||||
"token-classification",
|
"token-classification",
|
||||||
onnx_config_cls="models.roformer.RoFormerOnnxConfig",
|
onnx_config_cls="models.roformer.RoFormerOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"segformer": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"image-classification",
|
||||||
|
"semantic-segmentation",
|
||||||
|
onnx_config_cls="models.segformer.SegformerOnnxConfig",
|
||||||
|
),
|
||||||
"squeezebert": supported_features_mapping(
|
"squeezebert": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
|
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
|
||||||
("longformer", "allenai/longformer-base-4096"),
|
("longformer", "allenai/longformer-base-4096"),
|
||||||
("yolos", "hustvl/yolos-tiny"),
|
("yolos", "hustvl/yolos-tiny"),
|
||||||
|
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
|
||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||||
|
|||||||
Reference in New Issue
Block a user