add mobilebert onnx configs (#17029)

* update docs of length_penalty

* Revert "update docs of length_penalty"

This reverts commit 466bf4800b75ec29bd2ff75bad8e8973bd98d01c.

* add mobilebert onnx config

* address suggestions

* Update auto.mdx

* Update __init__.py

* Update features.py
This commit is contained in:
Manan Dey
2022-05-09 20:06:53 +05:30
committed by GitHub
parent a021f2b90c
commit dc3645dc9c
8 changed files with 56 additions and 2 deletions

View File

@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
[[autodoc]] TFAutoModelForMultipleChoice [[autodoc]] TFAutoModelForMultipleChoice
## TFAutoModelForNextSentencePrediction
[[autodoc]] TFAutoModelForNextSentencePrediction
## TFAutoModelForTableQuestionAnswering ## TFAutoModelForTableQuestionAnswering
[[autodoc]] TFAutoModelForTableQuestionAnswering [[autodoc]] TFAutoModelForTableQuestionAnswering

View File

@@ -68,6 +68,7 @@ Ready-made configurations include the following architectures:
- M2M100 - M2M100
- Marian - Marian
- mBART - mBART
- MobileBert
- OpenAI GPT-2 - OpenAI GPT-2
- PLBart - PLBart
- RoBERTa - RoBERTa

View File

@@ -1798,6 +1798,7 @@ if is_tf_available():
"TFAutoModelForSeq2SeqLM", "TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification", "TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq", "TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForNextSentencePrediction",
"TFAutoModelForTableQuestionAnswering", "TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification", "TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq", "TFAutoModelForVision2Seq",
@@ -3964,6 +3965,7 @@ if TYPE_CHECKING:
TFAutoModelForImageClassification, TFAutoModelForImageClassification,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining, TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM,

View File

@@ -108,6 +108,7 @@ if is_tf_available():
"TFAutoModelForSeq2SeqLM", "TFAutoModelForSeq2SeqLM",
"TFAutoModelForSequenceClassification", "TFAutoModelForSequenceClassification",
"TFAutoModelForSpeechSeq2Seq", "TFAutoModelForSpeechSeq2Seq",
"TFAutoModelForNextSentencePrediction",
"TFAutoModelForTableQuestionAnswering", "TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification", "TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq", "TFAutoModelForVision2Seq",
@@ -224,6 +225,7 @@ if TYPE_CHECKING:
TFAutoModelForImageClassification, TFAutoModelForImageClassification,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining, TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM,

View File

@@ -22,7 +22,11 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
_import_structure = { _import_structure = {
"configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"], "configuration_mobilebert": [
"MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"MobileBertConfig",
"MobileBertOnnxConfig",
],
"tokenization_mobilebert": ["MobileBertTokenizer"], "tokenization_mobilebert": ["MobileBertTokenizer"],
} }
@@ -62,7 +66,11 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig from .configuration_mobilebert import (
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MobileBertConfig,
MobileBertOnnxConfig,
)
from .tokenization_mobilebert import MobileBertTokenizer from .tokenization_mobilebert import MobileBertTokenizer
if is_tokenizers_available(): if is_tokenizers_available():

View File

@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" MobileBERT model configuration""" """ MobileBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
@@ -165,3 +168,20 @@ class MobileBertConfig(PretrainedConfig):
self.true_hidden_size = hidden_size self.true_hidden_size = hidden_size
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert
class MobileBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@@ -25,6 +25,7 @@ from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.m2m_100 import M2M100OnnxConfig from ..models.m2m_100 import M2M100OnnxConfig
from ..models.marian import MarianOnnxConfig from ..models.marian import MarianOnnxConfig
from ..models.mbart import MBartOnnxConfig from ..models.mbart import MBartOnnxConfig
from ..models.mobilebert import MobileBertOnnxConfig
from ..models.roberta import RobertaOnnxConfig from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig from ..models.roformer import RoFormerOnnxConfig
from ..models.t5 import T5OnnxConfig from ..models.t5 import T5OnnxConfig
@@ -44,6 +45,7 @@ if is_torch_available():
AutoModelForMaskedImageModeling, AutoModelForMaskedImageModeling,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForMultipleChoice, AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
@@ -55,6 +57,7 @@ if is_tf_available():
TFAutoModelForCausalLM, TFAutoModelForCausalLM,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction,
TFAutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification, TFAutoModelForSequenceClassification,
@@ -108,6 +111,7 @@ class FeaturesManager:
"question-answering": AutoModelForQuestionAnswering, "question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification, "image-classification": AutoModelForImageClassification,
"masked-im": AutoModelForMaskedImageModeling, "masked-im": AutoModelForMaskedImageModeling,
"next-sentence-prediction": AutoModelForNextSentencePrediction,
} }
if is_tf_available(): if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = { _TASKS_TO_TF_AUTOMODELS = {
@@ -119,6 +123,7 @@ class FeaturesManager:
"token-classification": TFAutoModelForTokenClassification, "token-classification": TFAutoModelForTokenClassification,
"multiple-choice": TFAutoModelForMultipleChoice, "multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering, "question-answering": TFAutoModelForQuestionAnswering,
"next-sentence-prediction": TFAutoModelForNextSentencePrediction,
} }
# Set of model topologies we support associated to the features supported by each topology and the factory # Set of model topologies we support associated to the features supported by each topology and the factory
@@ -153,6 +158,7 @@ class FeaturesManager:
"multiple-choice", "multiple-choice",
"token-classification", "token-classification",
"question-answering", "question-answering",
"next-sentence-prediction",
onnx_config_cls=BertOnnxConfig, onnx_config_cls=BertOnnxConfig,
), ),
"big-bird": supported_features_mapping( "big-bird": supported_features_mapping(
@@ -316,6 +322,16 @@ class FeaturesManager:
"question-answering", "question-answering",
onnx_config_cls=MBartOnnxConfig, onnx_config_cls=MBartOnnxConfig,
), ),
"mobilebert": supported_features_mapping(
"default",
"masked-lm",
"next-sentence-prediction",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=MobileBertOnnxConfig,
),
"m2m-100": supported_features_mapping( "m2m-100": supported_features_mapping(
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
), ),

View File

@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"), ("roformer", "junnyu/roformer_chinese_base"),
("mobilebert", "google/mobilebert-uncased"),
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),