From 1f60df81b24f070d4e9a1a17c76d3903fefe10bf Mon Sep 17 00:00:00 2001 From: Thomas Chaigneau <50595514+ChainYo@users.noreply.github.com> Date: Tue, 26 Oct 2021 11:22:22 +0200 Subject: [PATCH] Add Camembert to models exportable with ONNX (#14059) Add Camembert to models exportable with ONNX Co-authored-by: Thomas.Chaigneau Co-authored-by: Michael Benayoun --- docs/source/serialization.rst | 1 + src/transformers/models/camembert/__init__.py | 4 ++-- .../models/camembert/configuration_camembert.py | 15 +++++++++++++++ src/transformers/onnx/features.py | 9 +++++++++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 53e075e13c..508eb2ac7a 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -43,6 +43,7 @@ Ready-made configurations include the following models: - ALBERT - BART - BERT +- CamemBERT - DistilBERT - GPT Neo - LayoutLM diff --git a/src/transformers/models/camembert/__init__.py b/src/transformers/models/camembert/__init__.py index fc91d758a7..6079655c38 100644 --- a/src/transformers/models/camembert/__init__.py +++ b/src/transformers/models/camembert/__init__.py @@ -28,7 +28,7 @@ from ...file_utils import ( _import_structure = { - "configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"], + "configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig", "CamembertOnnxConfig"], } if is_sentencepiece_available(): @@ -62,7 +62,7 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig + from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig, CamembertOnnxConfig if is_sentencepiece_available(): from .tokenization_camembert import CamembertTokenizer diff --git a/src/transformers/models/camembert/configuration_camembert.py b/src/transformers/models/camembert/configuration_camembert.py index 31f9d94a0d..8a55e1c320 100644 --- a/src/transformers/models/camembert/configuration_camembert.py +++ b/src/transformers/models/camembert/configuration_camembert.py @@ -15,6 +15,10 @@ # limitations under the License. """ CamemBERT configuration """ +from collections import OrderedDict +from typing import Mapping + +from ...onnx import OnnxConfig from ...utils import logging from ..roberta.configuration_roberta import RobertaConfig @@ -35,3 +39,14 @@ class CamembertConfig(RobertaConfig): """ model_type = "camembert" + + +class CamembertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 937189b28f..d685af4cf7 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -5,6 +5,7 @@ from .. import is_torch_available from ..models.albert import AlbertOnnxConfig from ..models.bart import BartOnnxConfig from ..models.bert import BertOnnxConfig +from ..models.camembert import CamembertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig @@ -62,6 +63,14 @@ class FeaturesManager: "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig), "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig), "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig), + "camembert": supported_features_mapping( + "default", + "causal-lm", + "sequence-classification", + "token-classification", + "question-answering", + onnx_config_cls=CamembertOnnxConfig, + ), "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig), "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig), "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),