Add RemBERT ONNX config (#20520)

* rembert onnx config

* formatting

Co-authored-by: Ho <erincho@bcd0745f972b.ant.amazon.com>
This commit is contained in:
Erin
2022-12-05 08:39:09 -08:00
committed by GitHub
parent afe2a466bb
commit 87282cb73c
5 changed files with 39 additions and 2 deletions

View File

@@ -93,6 +93,7 @@ Ready-made configurations include the following architectures:
- OWL-ViT - OWL-ViT
- Perceiver - Perceiver
- PLBart - PLBart
- RemBERT
- ResNet - ResNet
- RoBERTa - RoBERTa
- RoFormer - RoFormer

View File

@@ -28,7 +28,9 @@ from ...utils import (
) )
_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]} _import_structure = {
"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"]
}
try: try:
if not is_sentencepiece_available(): if not is_sentencepiece_available():
@@ -88,7 +90,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig, RemBertOnnxConfig
try: try:
if not is_sentencepiece_available(): if not is_sentencepiece_available():

View File

@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" RemBERT model configuration""" """ RemBERT model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
@@ -135,3 +138,23 @@ class RemBertConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache self.use_cache = use_cache
self.tie_word_embeddings = False self.tie_word_embeddings = False
class RemBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4

View File

@@ -447,6 +447,16 @@ class FeaturesManager:
"sequence-classification", "sequence-classification",
onnx_config_cls="models.perceiver.PerceiverOnnxConfig", onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
), ),
"rembert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls="models.rembert.RemBertOnnxConfig",
),
"resnet": supported_features_mapping( "resnet": supported_features_mapping(
"default", "default",
"image-classification", "image-classification",

View File

@@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = {
("owlvit", "google/owlvit-base-patch32"), ("owlvit", "google/owlvit-base-patch32"),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
("rembert", "google/rembert"),
("resnet", "microsoft/resnet-50"), ("resnet", "microsoft/resnet-50"),
("roberta", "hf-internal-testing/tiny-random-RobertaModel"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
("roformer", "hf-internal-testing/tiny-random-RoFormerModel"), ("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),