From 5af38953bb05fe722c2ec5c345f54c2712ce4573 Mon Sep 17 00:00:00 2001 From: Ritik Nandwal <48522685+nandwalritik@users.noreply.github.com> Date: Tue, 31 May 2022 18:56:06 +0530 Subject: [PATCH] Added XLM onnx config (#17030) * Add onnx configuration for xlm * Add supported features for xlm * Add xlm to models exportable with onnx * Add xlm architecture to test file * Modify docs * Make code quality fixes --- docs/source/en/serialization.mdx | 1 + src/transformers/models/xlm/__init__.py | 4 ++-- .../models/xlm/configuration_xlm.py | 20 +++++++++++++++++++ src/transformers/onnx/features.py | 11 ++++++++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 35 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 2bb449240b..d12e526278 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -75,6 +75,7 @@ Ready-made configurations include the following architectures: - RoFormer - T5 - ViT +- XLM - XLM-RoBERTa - XLM-RoBERTa-XL diff --git a/src/transformers/models/xlm/__init__.py b/src/transformers/models/xlm/__init__.py index 03232811fb..de9be348b9 100644 --- a/src/transformers/models/xlm/__init__.py +++ b/src/transformers/models/xlm/__init__.py @@ -22,7 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_availabl _import_structure = { - "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"], + "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"], "tokenization_xlm": ["XLMTokenizer"], } @@ -64,7 +64,7 @@ else: if TYPE_CHECKING: - from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig + from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig from .tokenization_xlm import XLMTokenizer try: diff --git a/src/transformers/models/xlm/configuration_xlm.py b/src/transformers/models/xlm/configuration_xlm.py index d6f70c6671..e14ad2ec6c 100644 --- a/src/transformers/models/xlm/configuration_xlm.py +++ b/src/transformers/models/xlm/configuration_xlm.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ XLM configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -228,3 +231,20 @@ class XLMConfig(PretrainedConfig): self.n_words = kwargs["n_words"] super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs) + + +# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig +class XLMOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index fcc7673bc2..dfb46f89ec 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -30,6 +30,7 @@ from ..models.roberta import RobertaOnnxConfig from ..models.roformer import RoFormerOnnxConfig from ..models.t5 import T5OnnxConfig from ..models.vit import ViTOnnxConfig +from ..models.xlm import XLMOnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig from ..utils import logging from .config import OnnxConfig @@ -357,6 +358,16 @@ class FeaturesManager: "vit": supported_features_mapping( "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig ), + "xlm": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=XLMOnnxConfig, + ), "xlm-roberta": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 5ebef03873..bdf08c4452 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -181,6 +181,7 @@ PYTORCH_EXPORT_MODELS = { ("roberta", "roberta-base"), ("roformer", "junnyu/roformer_chinese_base"), ("mobilebert", "google/mobilebert-uncased"), + ("xlm", "xlm-clm-ende-1024"), ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("vit", "google/vit-base-patch16-224"),