From 15bc776fec90f9761849ea34585c052caa40f744 Mon Sep 17 00:00:00 2001 From: Syed Abdul Gaffar Shakhadri Date: Fri, 23 Dec 2022 12:00:57 +0530 Subject: [PATCH] Add Onnx Config for PoolFormer (#20868) poolformer onnx Co-authored-by: syed --- docs/source/en/serialization.mdx | 1 + .../models/poolformer/__init__.py | 14 ++++++++++-- .../poolformer/configuration_poolformer.py | 22 +++++++++++++++++++ src/transformers/onnx/features.py | 3 +++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 39 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 88265240e3..7079a91f40 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -102,6 +102,7 @@ Ready-made configurations include the following architectures: - OWL-ViT - Perceiver - PLBart +- PoolFormer - RemBERT - ResNet - RoBERTa diff --git a/src/transformers/models/poolformer/__init__.py b/src/transformers/models/poolformer/__init__.py index 947ac753fb..79e82a2280 100644 --- a/src/transformers/models/poolformer/__init__.py +++ b/src/transformers/models/poolformer/__init__.py @@ -21,7 +21,13 @@ from typing import TYPE_CHECKING from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available -_import_structure = {"configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"]} +_import_structure = { + "configuration_poolformer": [ + "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", + "PoolFormerConfig", + "PoolFormerOnnxConfig", + ] +} try: if not is_vision_available(): @@ -47,7 +53,11 @@ else: if TYPE_CHECKING: - from .configuration_poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig + from .configuration_poolformer import ( + POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, + PoolFormerConfig, + PoolFormerOnnxConfig, + ) try: if not is_vision_available(): diff --git a/src/transformers/models/poolformer/configuration_poolformer.py b/src/transformers/models/poolformer/configuration_poolformer.py index 12f0cce373..c55f13b80c 100644 --- a/src/transformers/models/poolformer/configuration_poolformer.py +++ b/src/transformers/models/poolformer/configuration_poolformer.py @@ -13,8 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PoolFormer model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -125,3 +130,20 @@ class PoolFormerConfig(PretrainedConfig): self.layer_scale_init_value = layer_scale_init_value self.initializer_range = initializer_range super().__init__(**kwargs) + + +class PoolFormerOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 2e-3 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index b371844c9d..ff82bf60b3 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -447,6 +447,9 @@ class FeaturesManager: "sequence-classification", onnx_config_cls="models.perceiver.PerceiverOnnxConfig", ), + "poolformer": supported_features_mapping( + "default", "image-classification", onnx_config_cls="models.poolformer.PoolFormerOnnxConfig" + ), "rembert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index afe4c1c036..e7a0e15d24 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = { ("owlvit", "google/owlvit-base-patch32"), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)), + ("poolformer", "sail/poolformer_s12"), ("rembert", "google/rembert"), ("resnet", "microsoft/resnet-50"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"),