From e0be053e433d641e3a2965da114b618b4bfdbf1a Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Thu, 9 Jun 2022 15:31:02 +0200 Subject: [PATCH] Add ONNX support for ConvNeXT (#17627) --- docs/source/en/serialization.mdx | 1 + src/transformers/models/convnext/__init__.py | 6 +++-- .../models/convnext/configuration_convnext.py | 23 +++++++++++++++++++ src/transformers/onnx/features.py | 5 ++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index bf172bd199..d6579ec239 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -55,6 +55,7 @@ Ready-made configurations include the following architectures: - BlenderbotSmall - CamemBERT - ConvBERT +- ConvNeXT - Data2VecText - Data2VecVision - DeiT diff --git a/src/transformers/models/convnext/__init__.py b/src/transformers/models/convnext/__init__.py index 37873982b7..93000d5c66 100644 --- a/src/transformers/models/convnext/__init__.py +++ b/src/transformers/models/convnext/__init__.py @@ -27,7 +27,9 @@ from ...utils import ( ) -_import_structure = {"configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"]} +_import_structure = { + "configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig", "ConvNextOnnxConfig"] +} try: if not is_vision_available(): @@ -63,7 +65,7 @@ else: ] if TYPE_CHECKING: - from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig + from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig, ConvNextOnnxConfig try: if not is_vision_available(): diff --git a/src/transformers/models/convnext/configuration_convnext.py b/src/transformers/models/convnext/configuration_convnext.py index 74067ad337..9f77c00992 100644 --- a/src/transformers/models/convnext/configuration_convnext.py +++ b/src/transformers/models/convnext/configuration_convnext.py @@ -14,7 +14,13 @@ # limitations under the License. """ ConvNeXT model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -101,3 +107,20 @@ class ConvNextConfig(PretrainedConfig): self.layer_scale_init_value = layer_scale_init_value self.drop_path_rate = drop_path_rate self.image_size = image_size + + +class ConvNextOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index d29831c36b..a54bb84712 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -193,6 +193,11 @@ class FeaturesManager: "question-answering", onnx_config_cls="models.convbert.ConvBertOnnxConfig", ), + "convnext": supported_features_mapping( + "default", + "image-classification", + onnx_config_cls="models.convnext.ConvNextOnnxConfig", + ), "data2vec-text": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 2f73169def..83ba77b491 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ("ibert", "kssteven/ibert-roberta-base"), ("camembert", "camembert-base"), ("convbert", "YituTech/conv-bert-base"), + ("convnext", "facebook/convnext-tiny-224"), ("distilbert", "distilbert-base-cased"), ("electra", "google/electra-base-generator"), ("resnet", "microsoft/resnet-50"),