Add Camembert to models exportable with ONNX (#14059)
Add Camembert to models exportable with ONNX Co-authored-by: Thomas.Chaigneau <thomas.chaigneau@arkea.com> Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
This commit is contained in:
@@ -43,6 +43,7 @@ Ready-made configurations include the following models:
|
|||||||
- ALBERT
|
- ALBERT
|
||||||
- BART
|
- BART
|
||||||
- BERT
|
- BERT
|
||||||
|
- CamemBERT
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
- GPT Neo
|
- GPT Neo
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from ...file_utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
|
"configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig", "CamembertOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
@@ -62,7 +62,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig, CamembertOnnxConfig
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
from .tokenization_camembert import CamembertTokenizer
|
from .tokenization_camembert import CamembertTokenizer
|
||||||
|
|||||||
@@ -15,6 +15,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" CamemBERT configuration """
|
""" CamemBERT configuration """
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.configuration_roberta import RobertaConfig
|
from ..roberta.configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
@@ -35,3 +39,14 @@ class CamembertConfig(RobertaConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_type = "camembert"
|
model_type = "camembert"
|
||||||
|
|
||||||
|
|
||||||
|
class CamembertOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from .. import is_torch_available
|
|||||||
from ..models.albert import AlbertOnnxConfig
|
from ..models.albert import AlbertOnnxConfig
|
||||||
from ..models.bart import BartOnnxConfig
|
from ..models.bart import BartOnnxConfig
|
||||||
from ..models.bert import BertOnnxConfig
|
from ..models.bert import BertOnnxConfig
|
||||||
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.gpt2 import GPT2OnnxConfig
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
@@ -62,6 +63,14 @@ class FeaturesManager:
|
|||||||
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
|
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
|
||||||
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
|
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
|
||||||
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
|
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
|
||||||
|
"camembert": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"causal-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=CamembertOnnxConfig,
|
||||||
|
),
|
||||||
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
|
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
|
||||||
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
||||||
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
|
||||||
|
|||||||
Reference in New Issue
Block a user