Add OnnxConfig for ConvBERT (#16859)
* add OnnxConfig for ConvBert Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
This commit is contained in:
@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- Blenderbot
|
- Blenderbot
|
||||||
- BlenderbotSmall
|
- BlenderbotSmall
|
||||||
- CamemBERT
|
- CamemBERT
|
||||||
|
- ConvBERT
|
||||||
- Data2VecText
|
- Data2VecText
|
||||||
- Data2VecVision
|
- Data2VecVision
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig"],
|
"configuration_convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertOnnxConfig"],
|
||||||
"tokenization_convbert": ["ConvBertTokenizer"],
|
"tokenization_convbert": ["ConvBertTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
|
from .configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertOnnxConfig
|
||||||
from .tokenization_convbert import ConvBertTokenizer
|
from .tokenization_convbert import ConvBertTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -14,7 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" ConvBERT model configuration"""
|
""" ConvBERT model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -138,3 +142,20 @@ class ConvBertConfig(PretrainedConfig):
|
|||||||
self.conv_kernel_size = conv_kernel_size
|
self.conv_kernel_size = conv_kernel_size
|
||||||
self.num_groups = num_groups
|
self.num_groups = num_groups
|
||||||
self.classifier_dropout = classifier_dropout
|
self.classifier_dropout = classifier_dropout
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
|
||||||
|
class ConvBertOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
("token_type_ids", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
|
|||||||
from ..models.blenderbot import BlenderbotOnnxConfig
|
from ..models.blenderbot import BlenderbotOnnxConfig
|
||||||
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
||||||
from ..models.camembert import CamembertOnnxConfig
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
|
from ..models.convbert import ConvBertOnnxConfig
|
||||||
from ..models.data2vec import Data2VecTextOnnxConfig
|
from ..models.data2vec import Data2VecTextOnnxConfig
|
||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.electra import ElectraOnnxConfig
|
from ..models.electra import ElectraOnnxConfig
|
||||||
@@ -187,6 +188,15 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=CamembertOnnxConfig,
|
onnx_config_cls=CamembertOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"convbert": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=ConvBertOnnxConfig,
|
||||||
|
),
|
||||||
"distilbert": supported_features_mapping(
|
"distilbert": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("bigbird", "google/bigbird-roberta-base"),
|
("bigbird", "google/bigbird-roberta-base"),
|
||||||
("ibert", "kssteven/ibert-roberta-base"),
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
|
("convbert", "YituTech/conv-bert-base"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
("electra", "google/electra-base-generator"),
|
("electra", "google/electra-base-generator"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
|
|||||||
Reference in New Issue
Block a user