From 9de70f213eb234522095cc9af7b2fac53afc2d87 Mon Sep 17 00:00:00 2001 From: Jim Rohrer Date: Fri, 1 Apr 2022 03:52:42 -0500 Subject: [PATCH] Add ONNX export for BeiT (#16498) * Add beit onnx conversion support * Updated docs * Added cross reference to ViT ONNX config --- docs/source/serialization.mdx | 1 + src/transformers/models/beit/__init__.py | 4 ++-- .../models/beit/configuration_beit.py | 23 +++++++++++++++++++ src/transformers/onnx/features.py | 2 ++ tests/onnx/test_onnx_v2.py | 6 ++--- 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx index fc969aac4f..65fb5fa5cc 100644 --- a/docs/source/serialization.mdx +++ b/docs/source/serialization.mdx @@ -47,6 +47,7 @@ Ready-made configurations include the following architectures: - ALBERT - BART +- BEiT - BERT - Blenderbot - BlenderbotSmall diff --git a/src/transformers/models/beit/__init__.py b/src/transformers/models/beit/__init__.py index 319fb2880a..27c31775d3 100644 --- a/src/transformers/models/beit/__init__.py +++ b/src/transformers/models/beit/__init__.py @@ -22,7 +22,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available, is_visi _import_structure = { - "configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"], + "configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig", "BeitOnnxConfig"], } if is_vision_available(): @@ -48,7 +48,7 @@ if is_flax_available(): ] if TYPE_CHECKING: - from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig + from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig, BeitOnnxConfig if is_vision_available(): from .feature_extraction_beit import BeitFeatureExtractor diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index 9a1dfa8c20..7c47aba0c2 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -13,8 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """ BEiT model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -176,3 +181,21 @@ class BeitConfig(PretrainedConfig): self.auxiliary_num_convs = auxiliary_num_convs self.auxiliary_concat_input = auxiliary_concat_input self.semantic_loss_ignore_index = semantic_loss_ignore_index + + +# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig +class BeitOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 926137c594..cf5e55c521 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available from ..models.albert import AlbertOnnxConfig from ..models.bart import BartOnnxConfig +from ..models.beit import BeitOnnxConfig from ..models.bert import BertOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig @@ -270,6 +271,7 @@ class FeaturesManager: onnx_config_cls=ElectraOnnxConfig, ), "vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig), + "beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig), "blenderbot": supported_features_mapping( "default", "default-with-past", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index f530515aed..ba8d51158f 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -15,14 +15,13 @@ from transformers.onnx import ( export, validate_model_outputs, ) +from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size +from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow if is_torch_available() or is_tf_available(): from transformers.onnx.features import FeaturesManager -from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size -from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow - @require_onnx class OnnxUtilsTestCaseV2(TestCase): @@ -181,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("vit", "google/vit-base-patch16-224"), + ("beit", "microsoft/beit-base-patch16-224"), } PYTORCH_EXPORT_WITH_PAST_MODELS = {