Add OnnxConfig for ConvBERT (#16859)

* add OnnxConfig for ConvBert

Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
This commit is contained in:
Thomas Chaigneau
2022-04-22 18:19:15 +02:00
committed by GitHub
parent 0d1cff1195
commit ec81c11a18
5 changed files with 35 additions and 2 deletions

View File

@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
- Blenderbot - Blenderbot
- BlenderbotSmall - BlenderbotSmall
- CamemBERT - CamemBERT
- ConvBERT
- Data2VecText - Data2VecText
- Data2VecVision - Data2VecVision
- DistilBERT - DistilBERT

View File

@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
_import_structure = { _import_structure = {
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig"], "configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertOnnxConfig"],
"tokenization_convbert": ["ConvBertTokenizer"], "tokenization_convbert": ["ConvBertTokenizer"],
} }
@@ -58,7 +58,7 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig
from .tokenization_convbert import ConvBertTokenizer from .tokenization_convbert import ConvBertTokenizer
if is_tokenizers_available(): if is_tokenizers_available():

View File

@@ -14,7 +14,11 @@
# limitations under the License. # limitations under the License.
""" ConvBERT model configuration""" """ ConvBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
@@ -138,3 +142,20 @@ class ConvBertConfig(PretrainedConfig):
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
self.num_groups = num_groups self.num_groups = num_groups
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
class ConvBertOnnxConfig(OnnxConfig):
    """ONNX export configuration for ConvBERT models."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Describe the dynamic axes of every model input.

        Returns:
            An ordered mapping from input name (`input_ids`, `attention_mask`,
            `token_type_ids`) to a mapping of axis index -> axis name. The
            `"multiple-choice"` task gets an extra `choice` axis between
            `batch` and `sequence`; every other task uses `batch` and
            `sequence` only.
        """
        # Start from the common two-axis layout and widen it only for multiple choice.
        axes = {0: "batch", 1: "sequence"}
        if self.task == "multiple-choice":
            axes = {0: "batch", 1: "choice", 2: "sequence"}

        # All three inputs share the same axes mapping object.
        input_names = ("input_ids", "attention_mask", "token_type_ids")
        return OrderedDict((name, axes) for name in input_names)

View File

@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig from ..models.camembert import CamembertOnnxConfig
from ..models.convbert import ConvBertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig from ..models.electra import ElectraOnnxConfig
@@ -187,6 +188,15 @@ class FeaturesManager:
"question-answering", "question-answering",
onnx_config_cls=CamembertOnnxConfig, onnx_config_cls=CamembertOnnxConfig,
), ),
"convbert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=ConvBertOnnxConfig,
),
"distilbert": supported_features_mapping( "distilbert": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",

View File

@@ -175,6 +175,7 @@ PYTORCH_EXPORT_MODELS = {
("bigbird", "google/bigbird-roberta-base"), ("bigbird", "google/bigbird-roberta-base"),
("ibert", "kssteven/ibert-roberta-base"), ("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"), ("camembert", "camembert-base"),
("convbert", "YituTech/conv-bert-base"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),