From 87d08afb169693a679c519024939e296f751a391 Mon Sep 17 00:00:00 2001 From: aaron <96719527+arron1227@users.noreply.github.com> Date: Tue, 8 Feb 2022 06:47:49 -0800 Subject: [PATCH] electra is added to onnx supported model (#15084) * electra is added to onnx supported model * add google/electra-base-generator for test onnx module Co-authored-by: Lewis Tunstall --- docs/source/serialization.mdx | 1 + src/transformers/models/electra/__init__.py | 4 ++-- .../models/electra/configuration_electra.py | 16 ++++++++++++++++ src/transformers/onnx/features.py | 10 ++++++++++ tests/test_onnx_v2.py | 1 + 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx index b985fabaa2..2be3fc3735 100644 --- a/docs/source/serialization.mdx +++ b/docs/source/serialization.mdx @@ -50,6 +50,7 @@ Ready-made configurations include the following architectures: - BERT - CamemBERT - DistilBERT +- ELECTRA - GPT Neo - I-BERT - LayoutLM diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py index 1aad02a412..a8f1ec5db7 100644 --- a/src/transformers/models/electra/__init__.py +++ b/src/transformers/models/electra/__init__.py @@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to _import_structure = { - "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig"], + "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraOnnxConfig"], "tokenization_electra": ["ElectraTokenizer"], } @@ -71,7 +71,7 @@ if is_flax_available(): if TYPE_CHECKING: - from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig + from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig from .tokenization_electra import ElectraTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/electra/configuration_electra.py b/src/transformers/models/electra/configuration_electra.py index 43b3b3255d..9b4525d3dc 100644 --- a/src/transformers/models/electra/configuration_electra.py +++ b/src/transformers/models/electra/configuration_electra.py @@ -15,7 +15,11 @@ # limitations under the License. """ ELECTRA model configuration""" +from collections import OrderedDict +from typing import Mapping + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -170,3 +174,15 @@ class ElectraConfig(PretrainedConfig): self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout + + +class ElectraOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ("token_type_ids", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 1020387592..9f3ad05b4f 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig from ..models.bert import BertOnnxConfig from ..models.camembert import CamembertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig +from ..models.electra import ElectraOnnxConfig from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.ibert import IBertOnnxConfig @@ -209,6 +210,15 @@ class FeaturesManager: "token-classification", onnx_config_cls=LayoutLMOnnxConfig, ), + "electra": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + "token-classification", + "question-answering", + onnx_config_cls=ElectraOnnxConfig, + ), } AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index c159431027..8bd88b7445 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -174,6 +174,7 @@ PYTORCH_EXPORT_MODELS = { ("ibert", "kssteven/ibert-roberta-base"), ("camembert", "camembert-base"), ("distilbert", "distilbert-base-cased"), + ("electra", "google/electra-base-generator"), ("roberta", "roberta-base"), ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"),