electra is added to onnx supported model (#15084)
* electra is added to onnx supported model * add google/electra-base-generator for test onnx module Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -50,6 +50,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- BERT
|
- BERT
|
||||||
- CamemBERT
|
- CamemBERT
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
|
- ELECTRA
|
||||||
- GPT Neo
|
- GPT Neo
|
||||||
- I-BERT
|
- I-BERT
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig"],
|
"configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraOnnxConfig"],
|
||||||
"tokenization_electra": ["ElectraTokenizer"],
|
"tokenization_electra": ["ElectraTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
|
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig
|
||||||
from .tokenization_electra import ElectraTokenizer
|
from .tokenization_electra import ElectraTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -15,7 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" ELECTRA model configuration"""
|
""" ELECTRA model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -170,3 +174,15 @@ class ElectraConfig(PretrainedConfig):
|
|||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.classifier_dropout = classifier_dropout
|
self.classifier_dropout = classifier_dropout
|
||||||
|
|
||||||
|
|
||||||
|
class ElectraOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
("token_type_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig
|
|||||||
from ..models.bert import BertOnnxConfig
|
from ..models.bert import BertOnnxConfig
|
||||||
from ..models.camembert import CamembertOnnxConfig
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
|
from ..models.electra import ElectraOnnxConfig
|
||||||
from ..models.gpt2 import GPT2OnnxConfig
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
from ..models.ibert import IBertOnnxConfig
|
from ..models.ibert import IBertOnnxConfig
|
||||||
@@ -209,6 +210,15 @@ class FeaturesManager:
|
|||||||
"token-classification",
|
"token-classification",
|
||||||
onnx_config_cls=LayoutLMOnnxConfig,
|
onnx_config_cls=LayoutLMOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"electra": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"causal-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=ElectraOnnxConfig,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("ibert", "kssteven/ibert-roberta-base"),
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
|
("electra", "google/electra-base-generator"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
|
|||||||
Reference in New Issue
Block a user