From dc3645dc9c72c72481fdb6ed44ff822764347883 Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Mon, 9 May 2022 20:06:53 +0530 Subject: [PATCH] add `mobilebert` onnx configs (#17029) * update docs of length_penalty * Revert "update docs of length_penalty" This reverts commit 466bf4800b75ec29bd2ff75bad8e8973bd98d01c. * add mobilebert onnx config * address suggestions * Update auto.mdx * Update __init__.py * Update features.py --- docs/source/en/model_doc/auto.mdx | 4 ++++ docs/source/en/serialization.mdx | 1 + src/transformers/__init__.py | 2 ++ src/transformers/models/auto/__init__.py | 2 ++ .../models/mobilebert/__init__.py | 12 +++++++++-- .../mobilebert/configuration_mobilebert.py | 20 +++++++++++++++++++ src/transformers/onnx/features.py | 16 +++++++++++++++ tests/onnx/test_onnx_v2.py | 1 + 8 files changed, 56 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index d941b00318..4a4b59e9c1 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] TFAutoModelForMultipleChoice +## TFAutoModelForNextSentencePrediction + +[[autodoc]] TFAutoModelForNextSentencePrediction + ## TFAutoModelForTableQuestionAnswering [[autodoc]] TFAutoModelForTableQuestionAnswering diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 0866632020..510fbf9363 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -68,6 +68,7 @@ Ready-made configurations include the following architectures: - M2M100 - Marian - mBART +- MobileBert - OpenAI GPT-2 - PLBart - RoBERTa diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6d976ef6f2..2c41ff883f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1798,6 +1798,7 @@ if is_tf_available(): "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", + "TFAutoModelForNextSentencePrediction", "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", @@ -3964,6 +3965,7 @@ if TYPE_CHECKING: TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 6dace993cd..fa34a11964 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -108,6 +108,7 @@ if is_tf_available(): "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", + "TFAutoModelForNextSentencePrediction", "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", @@ -224,6 +225,7 @@ if TYPE_CHECKING: TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, diff --git a/src/transformers/models/mobilebert/__init__.py b/src/transformers/models/mobilebert/__init__.py index 505dabe187..b35fe8a9c1 100644 --- a/src/transformers/models/mobilebert/__init__.py +++ b/src/transformers/models/mobilebert/__init__.py @@ -22,7 +22,11 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t _import_structure = { - "configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"], + "configuration_mobilebert": [ + "MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "MobileBertConfig", + "MobileBertOnnxConfig", + ], "tokenization_mobilebert": ["MobileBertTokenizer"], } @@ -62,7 +66,11 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig + from .configuration_mobilebert import ( + MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + MobileBertConfig, + MobileBertOnnxConfig, + ) from .tokenization_mobilebert import MobileBertTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/mobilebert/configuration_mobilebert.py b/src/transformers/models/mobilebert/configuration_mobilebert.py index 27863235b3..73b8844ed7 100644 --- a/src/transformers/models/mobilebert/configuration_mobilebert.py +++ b/src/transformers/models/mobilebert/configuration_mobilebert.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ MobileBERT model configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -165,3 +168,20 @@ class MobileBertConfig(PretrainedConfig): self.true_hidden_size = hidden_size self.classifier_dropout = classifier_dropout + + +# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert +class MobileBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 852a8e9071..c75cef897c 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -25,6 +25,7 @@ from ..models.layoutlm import LayoutLMOnnxConfig from ..models.m2m_100 import M2M100OnnxConfig from ..models.marian import MarianOnnxConfig from ..models.mbart import MBartOnnxConfig +from ..models.mobilebert import MobileBertOnnxConfig from ..models.roberta import RobertaOnnxConfig from ..models.roformer import RoFormerOnnxConfig from ..models.t5 import T5OnnxConfig @@ -44,6 +45,7 @@ if is_torch_available(): AutoModelForMaskedImageModeling, AutoModelForMaskedLM, AutoModelForMultipleChoice, + AutoModelForNextSentencePrediction, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, @@ -55,6 +57,7 @@ if is_tf_available(): TFAutoModelForCausalLM, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, TFAutoModelForSequenceClassification, @@ -108,6 +111,7 @@ class FeaturesManager: "question-answering": AutoModelForQuestionAnswering, "image-classification": AutoModelForImageClassification, "masked-im": AutoModelForMaskedImageModeling, + "next-sentence-prediction": AutoModelForNextSentencePrediction, } if is_tf_available(): _TASKS_TO_TF_AUTOMODELS = { @@ -119,6 +123,7 @@ class FeaturesManager: "token-classification": TFAutoModelForTokenClassification, "multiple-choice": TFAutoModelForMultipleChoice, "question-answering": TFAutoModelForQuestionAnswering, + "next-sentence-prediction": TFAutoModelForNextSentencePrediction, } # Set of model topologies we support associated to the features supported by each topology and the factory @@ -153,6 +158,7 @@ class FeaturesManager: "multiple-choice", "token-classification", "question-answering", + "next-sentence-prediction", onnx_config_cls=BertOnnxConfig, ), "big-bird": supported_features_mapping( @@ -316,6 +322,16 @@ class FeaturesManager: "question-answering", onnx_config_cls=MBartOnnxConfig, ), + "mobilebert": supported_features_mapping( + "default", + "masked-lm", + "next-sentence-prediction", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=MobileBertOnnxConfig, + ), "m2m-100": supported_features_mapping( "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig ), diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 9fc228f895..43a3ad45e1 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = { ("electra", "google/electra-base-generator"), ("roberta", "roberta-base"), ("roformer", "junnyu/roformer_chinese_base"), + ("mobilebert", "google/mobilebert-uncased"), ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("vit", "google/vit-base-patch16-224"),