Add Camembert to models exportable with ONNX (#14059)

Add Camembert to models exportable with ONNX

Co-authored-by: Thomas.Chaigneau <thomas.chaigneau@arkea.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
This commit is contained in:
Thomas Chaigneau
2021-10-26 11:22:22 +02:00
committed by GitHub
parent 0c3174c758
commit 1f60df81b2
4 changed files with 27 additions and 2 deletions

View File

@@ -43,6 +43,7 @@ Ready-made configurations include the following models:
- ALBERT - ALBERT
- BART - BART
- BERT - BERT
- CamemBERT
- DistilBERT - DistilBERT
- GPT Neo - GPT Neo
- LayoutLM - LayoutLM

View File

@@ -28,7 +28,7 @@ from ...file_utils import (
_import_structure = { _import_structure = {
"configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"], "configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig", "CamembertOnnxConfig"],
} }
if is_sentencepiece_available(): if is_sentencepiece_available():
@@ -62,7 +62,7 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig, CamembertOnnxConfig
if is_sentencepiece_available(): if is_sentencepiece_available():
from .tokenization_camembert import CamembertTokenizer from .tokenization_camembert import CamembertTokenizer

View File

@@ -15,6 +15,10 @@
# limitations under the License. # limitations under the License.
""" CamemBERT configuration """ """ CamemBERT configuration """
from collections import OrderedDict
from typing import Mapping
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig from ..roberta.configuration_roberta import RobertaConfig
@@ -35,3 +39,14 @@ class CamembertConfig(RobertaConfig):
""" """
model_type = "camembert" model_type = "camembert"
class CamembertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)

View File

@@ -5,6 +5,7 @@ from .. import is_torch_available
from ..models.albert import AlbertOnnxConfig from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig from ..models.bert import BertOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig
@@ -62,6 +63,14 @@ class FeaturesManager:
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig), "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig), "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig), "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
"camembert": supported_features_mapping(
"default",
"causal-lm",
"sequence-classification",
"token-classification",
"question-answering",
onnx_config_cls=CamembertOnnxConfig,
),
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig), "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig), "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig), "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),