diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index f8e88b3501..4255b8f6e1 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -53,6 +53,7 @@ Ready-made configurations include the following architectures: - Blenderbot - BlenderbotSmall - CamemBERT +- ConvBERT - Data2VecText - Data2VecVision - DistilBERT diff --git a/src/transformers/models/convbert/__init__.py b/src/transformers/models/convbert/__init__.py index 77665acfea..d4f44482e0 100644 --- a/src/transformers/models/convbert/__init__.py +++ b/src/transformers/models/convbert/__init__.py @@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t _import_structure = { - "configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig"], + "configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertOnnxConfig"], "tokenization_convbert": ["ConvBertTokenizer"], } @@ -58,7 +58,7 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig + from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig from .tokenization_convbert import ConvBertTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/convbert/configuration_convbert.py b/src/transformers/models/convbert/configuration_convbert.py index 0f654eebb4..c424326b2b 100644 --- a/src/transformers/models/convbert/configuration_convbert.py +++ b/src/transformers/models/convbert/configuration_convbert.py @@ -14,7 +14,11 @@ # limitations under the License. """ ConvBERT model configuration""" +from collections import OrderedDict +from typing import Mapping + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -138,3 +142,20 @@ class ConvBertConfig(PretrainedConfig): self.conv_kernel_size = conv_kernel_size self.num_groups = num_groups self.classifier_dropout = classifier_dropout + + +# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig +class ConvBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 2b65770eb8..9163de8c21 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.camembert import CamembertOnnxConfig +from ..models.convbert import ConvBertOnnxConfig from ..models.data2vec import Data2VecTextOnnxConfig from ..models.distilbert import DistilBertOnnxConfig from ..models.electra import ElectraOnnxConfig @@ -187,6 +188,15 @@ class FeaturesManager: "question-answering", onnx_config_cls=CamembertOnnxConfig, ), + "convbert": supported_features_mapping( + "default", + "masked-lm", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=ConvBertOnnxConfig, + ), "distilbert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 9e3ee73616..f85fe71218 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -175,6 +175,7 @@ PYTORCH_EXPORT_MODELS = { ("bigbird", "google/bigbird-roberta-base"), ("ibert", "kssteven/ibert-roberta-base"), ("camembert", "camembert-base"), + ("convbert", "YituTech/conv-bert-base"), ("distilbert", "distilbert-base-cased"), ("electra", "google/electra-base-generator"), ("roberta", "roberta-base"),