add mobilebert onnx configs (#17029)
* update docs of length_penalty * Revert "update docs of length_penalty" This reverts commit 466bf4800b75ec29bd2ff75bad8e8973bd98d01c. * add mobilebert onnx config * address suggestions * Update auto.mdx * Update __init__.py * Update features.py
This commit is contained in:
@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
|||||||
|
|
||||||
[[autodoc]] TFAutoModelForMultipleChoice
|
[[autodoc]] TFAutoModelForMultipleChoice
|
||||||
|
|
||||||
|
## TFAutoModelForNextSentencePrediction
|
||||||
|
|
||||||
|
[[autodoc]] TFAutoModelForNextSentencePrediction
|
||||||
|
|
||||||
## TFAutoModelForTableQuestionAnswering
|
## TFAutoModelForTableQuestionAnswering
|
||||||
|
|
||||||
[[autodoc]] TFAutoModelForTableQuestionAnswering
|
[[autodoc]] TFAutoModelForTableQuestionAnswering
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- M2M100
|
- M2M100
|
||||||
- Marian
|
- Marian
|
||||||
- mBART
|
- mBART
|
||||||
|
- MobileBert
|
||||||
- OpenAI GPT-2
|
- OpenAI GPT-2
|
||||||
- PLBart
|
- PLBart
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
|
|||||||
@@ -1798,6 +1798,7 @@ if is_tf_available():
|
|||||||
"TFAutoModelForSeq2SeqLM",
|
"TFAutoModelForSeq2SeqLM",
|
||||||
"TFAutoModelForSequenceClassification",
|
"TFAutoModelForSequenceClassification",
|
||||||
"TFAutoModelForSpeechSeq2Seq",
|
"TFAutoModelForSpeechSeq2Seq",
|
||||||
|
"TFAutoModelForNextSentencePrediction",
|
||||||
"TFAutoModelForTableQuestionAnswering",
|
"TFAutoModelForTableQuestionAnswering",
|
||||||
"TFAutoModelForTokenClassification",
|
"TFAutoModelForTokenClassification",
|
||||||
"TFAutoModelForVision2Seq",
|
"TFAutoModelForVision2Seq",
|
||||||
@@ -3964,6 +3965,7 @@ if TYPE_CHECKING:
|
|||||||
TFAutoModelForImageClassification,
|
TFAutoModelForImageClassification,
|
||||||
TFAutoModelForMaskedLM,
|
TFAutoModelForMaskedLM,
|
||||||
TFAutoModelForMultipleChoice,
|
TFAutoModelForMultipleChoice,
|
||||||
|
TFAutoModelForNextSentencePrediction,
|
||||||
TFAutoModelForPreTraining,
|
TFAutoModelForPreTraining,
|
||||||
TFAutoModelForQuestionAnswering,
|
TFAutoModelForQuestionAnswering,
|
||||||
TFAutoModelForSeq2SeqLM,
|
TFAutoModelForSeq2SeqLM,
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ if is_tf_available():
|
|||||||
"TFAutoModelForSeq2SeqLM",
|
"TFAutoModelForSeq2SeqLM",
|
||||||
"TFAutoModelForSequenceClassification",
|
"TFAutoModelForSequenceClassification",
|
||||||
"TFAutoModelForSpeechSeq2Seq",
|
"TFAutoModelForSpeechSeq2Seq",
|
||||||
|
"TFAutoModelForNextSentencePrediction",
|
||||||
"TFAutoModelForTableQuestionAnswering",
|
"TFAutoModelForTableQuestionAnswering",
|
||||||
"TFAutoModelForTokenClassification",
|
"TFAutoModelForTokenClassification",
|
||||||
"TFAutoModelForVision2Seq",
|
"TFAutoModelForVision2Seq",
|
||||||
@@ -224,6 +225,7 @@ if TYPE_CHECKING:
|
|||||||
TFAutoModelForImageClassification,
|
TFAutoModelForImageClassification,
|
||||||
TFAutoModelForMaskedLM,
|
TFAutoModelForMaskedLM,
|
||||||
TFAutoModelForMultipleChoice,
|
TFAutoModelForMultipleChoice,
|
||||||
|
TFAutoModelForNextSentencePrediction,
|
||||||
TFAutoModelForPreTraining,
|
TFAutoModelForPreTraining,
|
||||||
TFAutoModelForQuestionAnswering,
|
TFAutoModelForQuestionAnswering,
|
||||||
TFAutoModelForSeq2SeqLM,
|
TFAutoModelForSeq2SeqLM,
|
||||||
|
|||||||
@@ -22,7 +22,11 @@ from ...utils import _LazyModule, is_tf_available, is_tokenizers_available, is_t
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"],
|
"configuration_mobilebert": [
|
||||||
|
"MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
|
"MobileBertConfig",
|
||||||
|
"MobileBertOnnxConfig",
|
||||||
|
],
|
||||||
"tokenization_mobilebert": ["MobileBertTokenizer"],
|
"tokenization_mobilebert": ["MobileBertTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +66,11 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
|
from .configuration_mobilebert import (
|
||||||
|
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
MobileBertConfig,
|
||||||
|
MobileBertOnnxConfig,
|
||||||
|
)
|
||||||
from .tokenization_mobilebert import MobileBertTokenizer
|
from .tokenization_mobilebert import MobileBertTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -13,8 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" MobileBERT model configuration"""
|
""" MobileBERT model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -165,3 +168,20 @@ class MobileBertConfig(PretrainedConfig):
|
|||||||
self.true_hidden_size = hidden_size
|
self.true_hidden_size = hidden_size
|
||||||
|
|
||||||
self.classifier_dropout = classifier_dropout
|
self.classifier_dropout = classifier_dropout
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert
|
||||||
|
class MobileBertOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
("token_type_ids", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from ..models.layoutlm import LayoutLMOnnxConfig
|
|||||||
from ..models.m2m_100 import M2M100OnnxConfig
|
from ..models.m2m_100 import M2M100OnnxConfig
|
||||||
from ..models.marian import MarianOnnxConfig
|
from ..models.marian import MarianOnnxConfig
|
||||||
from ..models.mbart import MBartOnnxConfig
|
from ..models.mbart import MBartOnnxConfig
|
||||||
|
from ..models.mobilebert import MobileBertOnnxConfig
|
||||||
from ..models.roberta import RobertaOnnxConfig
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
from ..models.roformer import RoFormerOnnxConfig
|
from ..models.roformer import RoFormerOnnxConfig
|
||||||
from ..models.t5 import T5OnnxConfig
|
from ..models.t5 import T5OnnxConfig
|
||||||
@@ -44,6 +45,7 @@ if is_torch_available():
|
|||||||
AutoModelForMaskedImageModeling,
|
AutoModelForMaskedImageModeling,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
|
AutoModelForNextSentencePrediction,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
@@ -55,6 +57,7 @@ if is_tf_available():
|
|||||||
TFAutoModelForCausalLM,
|
TFAutoModelForCausalLM,
|
||||||
TFAutoModelForMaskedLM,
|
TFAutoModelForMaskedLM,
|
||||||
TFAutoModelForMultipleChoice,
|
TFAutoModelForMultipleChoice,
|
||||||
|
TFAutoModelForNextSentencePrediction,
|
||||||
TFAutoModelForQuestionAnswering,
|
TFAutoModelForQuestionAnswering,
|
||||||
TFAutoModelForSeq2SeqLM,
|
TFAutoModelForSeq2SeqLM,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
@@ -108,6 +111,7 @@ class FeaturesManager:
|
|||||||
"question-answering": AutoModelForQuestionAnswering,
|
"question-answering": AutoModelForQuestionAnswering,
|
||||||
"image-classification": AutoModelForImageClassification,
|
"image-classification": AutoModelForImageClassification,
|
||||||
"masked-im": AutoModelForMaskedImageModeling,
|
"masked-im": AutoModelForMaskedImageModeling,
|
||||||
|
"next-sentence-prediction": AutoModelForNextSentencePrediction,
|
||||||
}
|
}
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_TASKS_TO_TF_AUTOMODELS = {
|
_TASKS_TO_TF_AUTOMODELS = {
|
||||||
@@ -119,6 +123,7 @@ class FeaturesManager:
|
|||||||
"token-classification": TFAutoModelForTokenClassification,
|
"token-classification": TFAutoModelForTokenClassification,
|
||||||
"multiple-choice": TFAutoModelForMultipleChoice,
|
"multiple-choice": TFAutoModelForMultipleChoice,
|
||||||
"question-answering": TFAutoModelForQuestionAnswering,
|
"question-answering": TFAutoModelForQuestionAnswering,
|
||||||
|
"next-sentence-prediction": TFAutoModelForNextSentencePrediction,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Set of model topologies we support associated to the features supported by each topology and the factory
|
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||||
@@ -153,6 +158,7 @@ class FeaturesManager:
|
|||||||
"multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
|
"next-sentence-prediction",
|
||||||
onnx_config_cls=BertOnnxConfig,
|
onnx_config_cls=BertOnnxConfig,
|
||||||
),
|
),
|
||||||
"big-bird": supported_features_mapping(
|
"big-bird": supported_features_mapping(
|
||||||
@@ -316,6 +322,16 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=MBartOnnxConfig,
|
onnx_config_cls=MBartOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"mobilebert": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"next-sentence-prediction",
|
||||||
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=MobileBertOnnxConfig,
|
||||||
|
),
|
||||||
"m2m-100": supported_features_mapping(
|
"m2m-100": supported_features_mapping(
|
||||||
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
|
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("electra", "google/electra-base-generator"),
|
("electra", "google/electra-base-generator"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
("roformer", "junnyu/roformer_chinese_base"),
|
("roformer", "junnyu/roformer_chinese_base"),
|
||||||
|
("mobilebert", "google/mobilebert-uncased"),
|
||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
|
|||||||
Reference in New Issue
Block a user