Add RemBERT ONNX config (#20520)
* rembert onnx config * formatting Co-authored-by: Ho <erincho@bcd0745f972b.ant.amazon.com>
This commit is contained in:
@@ -93,6 +93,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- OWL-ViT
|
- OWL-ViT
|
||||||
- Perceiver
|
- Perceiver
|
||||||
- PLBart
|
- PLBart
|
||||||
|
- RemBERT
|
||||||
- ResNet
|
- ResNet
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
- RoFormer
|
- RoFormer
|
||||||
|
|||||||
@@ -28,7 +28,9 @@ from ...utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]}
|
_import_structure = {
|
||||||
|
"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"]
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_sentencepiece_available():
|
if not is_sentencepiece_available():
|
||||||
@@ -88,7 +90,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
|
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig, RemBertOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_sentencepiece_available():
|
if not is_sentencepiece_available():
|
||||||
|
|||||||
@@ -13,8 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" RemBERT model configuration"""
|
""" RemBERT model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -135,3 +138,23 @@ class RemBertConfig(PretrainedConfig):
|
|||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.tie_word_embeddings = False
|
self.tie_word_embeddings = False
|
||||||
|
|
||||||
|
|
||||||
|
class RemBertOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
("token_type_ids", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-4
|
||||||
|
|||||||
@@ -447,6 +447,16 @@ class FeaturesManager:
|
|||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
|
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"rembert": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"causal-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls="models.rembert.RemBertOnnxConfig",
|
||||||
|
),
|
||||||
"resnet": supported_features_mapping(
|
"resnet": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"image-classification",
|
"image-classification",
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("owlvit", "google/owlvit-base-patch32"),
|
("owlvit", "google/owlvit-base-patch32"),
|
||||||
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
|
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
|
||||||
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
|
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
|
||||||
|
("rembert", "google/rembert"),
|
||||||
("resnet", "microsoft/resnet-50"),
|
("resnet", "microsoft/resnet-50"),
|
||||||
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
|
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
|
||||||
("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),
|
("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),
|
||||||
|
|||||||
Reference in New Issue
Block a user