From 0aac9ba2dabcf97b22da8e9a0875b52b2113c062 Mon Sep 17 00:00:00 2001 From: Thomas Chaigneau <50595514+ChainYo@users.noreply.github.com> Date: Mon, 21 Mar 2022 21:46:31 +0100 Subject: [PATCH] Add Flaubert OnnxConfig to Transformers (#16279) * Add Flaubert to ONNX to make it available for conversion. * Fixed features for FlauBERT. fixup command remove flaubert to docs list. Co-authored-by: ChainYo --- docs/source/serialization.mdx | 1 + src/transformers/models/flaubert/__init__.py | 4 ++-- .../models/flaubert/configuration_flaubert.py | 15 +++++++++++++++ src/transformers/onnx/features.py | 10 ++++++++++ 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx index 81ff904a36..79e96872e9 100644 --- a/docs/source/serialization.mdx +++ b/docs/source/serialization.mdx @@ -52,6 +52,7 @@ Ready-made configurations include the following architectures: - Data2VecText - DistilBERT - ELECTRA +- FlauBERT - GPT Neo - I-BERT - LayoutLM diff --git a/src/transformers/models/flaubert/__init__.py b/src/transformers/models/flaubert/__init__.py index 538a9d6a8c..a2b6ed1ca8 100644 --- a/src/transformers/models/flaubert/__init__.py +++ b/src/transformers/models/flaubert/__init__.py @@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_tf_available, is_torch_available _import_structure = { - "configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig"], + "configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertOnnxConfig"], "tokenization_flaubert": ["FlaubertTokenizer"], } @@ -52,7 +52,7 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig + from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertOnnxConfig from .tokenization_flaubert import FlaubertTokenizer if is_torch_available(): diff --git a/src/transformers/models/flaubert/configuration_flaubert.py b/src/transformers/models/flaubert/configuration_flaubert.py index 037e860069..e4ec6414c2 100644 --- a/src/transformers/models/flaubert/configuration_flaubert.py +++ b/src/transformers/models/flaubert/configuration_flaubert.py @@ -14,6 +14,10 @@ # limitations under the License. """ Flaubert configuration, based on XLM.""" +from collections import OrderedDict +from typing import Mapping + +from ...onnx import OnnxConfig from ...utils import logging from ..xlm.configuration_xlm import XLMConfig @@ -137,3 +141,14 @@ class FlaubertConfig(XLMConfig): self.layerdrop = layerdrop self.pre_norm = pre_norm super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs) + + +class FlaubertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index cff3c45215..4e5bd8e9d3 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -8,6 +8,7 @@ from ..models.bert import BertOnnxConfig from ..models.camembert import CamembertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig from ..models.electra import ElectraOnnxConfig +from ..models.flaubert import FlaubertOnnxConfig from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.ibert import IBertOnnxConfig @@ -179,6 +180,15 @@ class FeaturesManager: "question-answering", onnx_config_cls=DistilBertOnnxConfig, ), + "flaubert": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + "token-classification", + "question-answering", + onnx_config_cls=FlaubertOnnxConfig, + ), "marian": supported_features_mapping( "default", "default-with-past",