From 8c14b342aad0ff112a26102019a07bf23d2e33fd Mon Sep 17 00:00:00 2001 From: gcheron Date: Mon, 18 Jul 2022 15:17:07 +0200 Subject: [PATCH] add ONNX support for LeVit (#18154) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Guilhem Chéron --- docs/source/en/serialization.mdx | 1 + src/transformers/models/levit/__init__.py | 4 ++-- .../models/levit/configuration_levit.py | 24 +++++++++++++++++++ src/transformers/onnx/features.py | 3 +++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 31 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index e376721cc1..e41ccae949 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -72,6 +72,7 @@ Ready-made configurations include the following architectures: - I-BERT - LayoutLM - LayoutLMv3 +- LeViT - LongT5 - M2M100 - Marian diff --git a/src/transformers/models/levit/__init__.py b/src/transformers/models/levit/__init__.py index bdbcaed41a..ea848f12a2 100644 --- a/src/transformers/models/levit/__init__.py +++ b/src/transformers/models/levit/__init__.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available -_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"]} +_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig", "LevitOnnxConfig"]} try: if not is_vision_available(): @@ -46,7 +46,7 @@ else: if TYPE_CHECKING: - from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig + from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig, LevitOnnxConfig try: if not is_vision_available(): diff --git a/src/transformers/models/levit/configuration_levit.py b/src/transformers/models/levit/configuration_levit.py index 5d75b9fc23..a1113d7a75 100644 --- a/src/transformers/models/levit/configuration_levit.py +++ b/src/transformers/models/levit/configuration_levit.py @@ -14,7 +14,13 @@ # limitations under the License. """ LeViT model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -120,3 +126,21 @@ class LevitConfig(PretrainedConfig): ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2], ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2], ] + + +# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig +class LevitOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 35c52ed603..e7c24a8ad9 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -333,6 +333,9 @@ class FeaturesManager: "token-classification", onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig", ), + "levit": supported_features_mapping( + "default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig" + ), "longt5": supported_features_mapping( "default", "default-with-past", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 3be3a34d9c..6b22dc3420 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -196,6 +196,7 @@ PYTORCH_EXPORT_MODELS = { ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlmv3", "microsoft/layoutlmv3-base"), + ("levit", "facebook/levit-128S"), ("vit", "google/vit-base-patch16-224"), ("deit", "facebook/deit-small-patch16-224"), ("beit", "microsoft/beit-base-patch16-224"),