From c4fa908fa98c3d538462c537d29b7613dd71306e Mon Sep 17 00:00:00 2001 From: Virus Date: Tue, 11 Jan 2022 14:17:08 +0300 Subject: [PATCH] Adds IBERT to models exportable with ONNX (#14868) * Add IBertOnnxConfig and tests * add all the supported features for IBERT and remove outputs in IbertOnnxConfig * use OnnxConfig * fix codestyle * remove serialization.rst * codestyle --- docs/source/serialization.mdx | 1 + src/transformers/models/ibert/__init__.py | 4 ++-- .../models/ibert/configuration_ibert.py | 15 +++++++++++++++ src/transformers/onnx/features.py | 10 ++++++++++ tests/test_onnx_v2.py | 1 + 5 files changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx index 091eac083b..85b8ee8005 100644 --- a/docs/source/serialization.mdx +++ b/docs/source/serialization.mdx @@ -40,6 +40,7 @@ Ready-made configurations include the following models: - CamemBERT - DistilBERT - GPT Neo +- I-BERT - LayoutLM - Longformer - Marian diff --git a/src/transformers/models/ibert/__init__.py b/src/transformers/models/ibert/__init__.py index 2e34d1224f..9ef9780807 100644 --- a/src/transformers/models/ibert/__init__.py +++ b/src/transformers/models/ibert/__init__.py @@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_torch_available _import_structure = { - "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"], + "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"], } if is_torch_available(): @@ -38,7 +38,7 @@ if is_torch_available(): ] if TYPE_CHECKING: - from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig + from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig, IBertOnnxConfig if is_torch_available(): from .modeling_ibert import ( diff --git a/src/transformers/models/ibert/configuration_ibert.py b/src/transformers/models/ibert/configuration_ibert.py index ad0fd8f927..8b96594cfe 100644 --- a/src/transformers/models/ibert/configuration_ibert.py +++ b/src/transformers/models/ibert/configuration_ibert.py @@ -15,6 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ I-BERT configuration""" +from collections import OrderedDict +from typing import Mapping + +from transformers.onnx import OnnxConfig from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -122,3 +126,14 @@ class IBertConfig(PretrainedConfig): self.position_embedding_type = position_embedding_type self.quant_mode = quant_mode self.force_dequant = force_dequant + + +class IBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 2f12a574bf..41f8970d70 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -9,6 +9,7 @@ from ..models.camembert import CamembertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig +from ..models.ibert import IBertOnnxConfig from ..models.layoutlm import LayoutLMOnnxConfig from ..models.longformer import LongformerOnnxConfig from ..models.marian import MarianOnnxConfig @@ -125,6 +126,15 @@ class FeaturesManager: "question-answering", onnx_config_cls=BertOnnxConfig, ), + "ibert": supported_features_mapping( + "default", + "masked-lm", + "sequence-classification", + # "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=IBertOnnxConfig, + ), "camembert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index fdd7694c41..53b28718a6 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -171,6 +171,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase): PYTORCH_EXPORT_MODELS = { ("albert", "hf-internal-testing/tiny-albert"), ("bert", "bert-base-cased"), + ("ibert", "kssteven/ibert-roberta-base"), ("camembert", "camembert-base"), ("distilbert", "distilbert-base-cased"), # ("longFormer", "longformer-base-4096"),