electra is added to onnx supported model (#15084)

* electra is added to onnx supported model

* add google/electra-base-generator for test onnx module

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
This commit is contained in:
aaron
2022-02-08 06:47:49 -08:00
committed by GitHub
parent 0fe17f375a
commit 87d08afb16
5 changed files with 30 additions and 2 deletions

View File

@@ -50,6 +50,7 @@ Ready-made configurations include the following architectures:
- BERT - BERT
- CamemBERT - CamemBERT
- DistilBERT - DistilBERT
- ELECTRA
- GPT Neo - GPT Neo
- I-BERT - I-BERT
- LayoutLM - LayoutLM

View File

@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
_import_structure = { _import_structure = {
"configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig"], "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraOnnxConfig"],
"tokenization_electra": ["ElectraTokenizer"], "tokenization_electra": ["ElectraTokenizer"],
} }
@@ -71,7 +71,7 @@ if is_flax_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig
from .tokenization_electra import ElectraTokenizer from .tokenization_electra import ElectraTokenizer
if is_tokenizers_available(): if is_tokenizers_available():

View File

@@ -15,7 +15,11 @@
# limitations under the License. # limitations under the License.
""" ELECTRA model configuration""" """ ELECTRA model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
@@ -170,3 +174,15 @@ class ElectraConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type self.position_embedding_type = position_embedding_type
self.use_cache = use_cache self.use_cache = use_cache
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
class ElectraOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)

View File

@@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig from ..models.bert import BertOnnxConfig
from ..models.camembert import CamembertOnnxConfig from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.ibert import IBertOnnxConfig from ..models.ibert import IBertOnnxConfig
@@ -209,6 +210,15 @@ class FeaturesManager:
"token-classification", "token-classification",
onnx_config_cls=LayoutLMOnnxConfig, onnx_config_cls=LayoutLMOnnxConfig,
), ),
"electra": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"token-classification",
"question-answering",
onnx_config_cls=ElectraOnnxConfig,
),
} }
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))

View File

@@ -174,6 +174,7 @@ PYTORCH_EXPORT_MODELS = {
("ibert", "kssteven/ibert-roberta-base"), ("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"), ("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),