From 7e7f743481abff9bcabdf73047dffb7c1db9d18b Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Wed, 31 Aug 2022 20:58:44 +0200 Subject: [PATCH] Add SegFormer ONNX support (#18006) * Add ONNX support * Make height and width dynamic axes Co-authored-by: Niels Rogge --- docs/source/en/serialization.mdx | 1 + src/transformers/models/segformer/__init__.py | 6 +++-- .../segformer/configuration_segformer.py | 26 +++++++++++++++++++ src/transformers/onnx/features.py | 6 +++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 38 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index d6bf15df7f..31ad430e06 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -90,6 +90,7 @@ Ready-made configurations include the following architectures: - ResNet - RoBERTa - RoFormer +- SegFormer - SqueezeBERT - T5 - ViT diff --git a/src/transformers/models/segformer/__init__.py b/src/transformers/models/segformer/__init__.py index 2317237509..7b8b60651d 100644 --- a/src/transformers/models/segformer/__init__.py +++ b/src/transformers/models/segformer/__init__.py @@ -26,7 +26,9 @@ from ...utils import ( ) -_import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]} +_import_structure = { + "configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig", "SegformerOnnxConfig"] +} try: if not is_vision_available(): @@ -69,7 +71,7 @@ else: if TYPE_CHECKING: - from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig + from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig, SegformerOnnxConfig try: if not is_vision_available(): diff --git a/src/transformers/models/segformer/configuration_segformer.py b/src/transformers/models/segformer/configuration_segformer.py index faec5d6c4c..8b98af0fae 100644 --- a/src/transformers/models/segformer/configuration_segformer.py +++ b/src/transformers/models/segformer/configuration_segformer.py @@ -15,8 +15,13 @@ """ SegFormer model configuration""" import warnings +from collections import OrderedDict +from typing import Mapping + +from packaging import version from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -148,3 +153,24 @@ class SegformerConfig(PretrainedConfig): self.decoder_hidden_size = decoder_hidden_size self.reshape_last_stage = kwargs.get("reshape_last_stage", True) self.semantic_loss_ignore_index = semantic_loss_ignore_index + + +class SegformerOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + @property + def default_onnx_opset(self) -> int: + return 12 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index b1ea30c1af..535686f179 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -456,6 +456,12 @@ class FeaturesManager: "token-classification", onnx_config_cls="models.roformer.RoFormerOnnxConfig", ), + "segformer": supported_features_mapping( + "default", + "image-classification", + "semantic-segmentation", + onnx_config_cls="models.segformer.SegformerOnnxConfig", + ), "squeezebert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 16ee78a63c..f3c19ed8fa 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -216,6 +216,7 @@ PYTORCH_EXPORT_MODELS = { ("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)), ("longformer", "allenai/longformer-base-4096"), ("yolos", "hustvl/yolos-tiny"), + ("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"), } PYTORCH_EXPORT_WITH_PAST_MODELS = {