Add onnx config for RoFormer (#16861)

* add roformer onnx config
This commit is contained in:
Krishna Sirumalla
2022-04-26 10:51:15 -04:00
committed by GitHub
parent 8afaaa26f5
commit aaee4038c3
5 changed files with 37 additions and 2 deletions

View File

@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
- OpenAI GPT-2 - OpenAI GPT-2
- PLBart - PLBart
- RoBERTa - RoBERTa
- RoFormer
- T5 - T5
- TAPEX - TAPEX
- ViT - ViT

View File

@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz
_import_structure = { _import_structure = {
"configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"], "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerOnnxConfig"],
"tokenization_roformer": ["RoFormerTokenizer"], "tokenization_roformer": ["RoFormerTokenizer"],
} }
@@ -73,7 +73,7 @@ if is_flax_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerOnnxConfig
from .tokenization_roformer import RoFormerTokenizer from .tokenization_roformer import RoFormerTokenizer
if is_tokenizers_available(): if is_tokenizers_available():

View File

@@ -14,7 +14,11 @@
# limitations under the License. # limitations under the License.
""" RoFormer model configuration""" """ RoFormer model configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
@@ -131,3 +135,20 @@ class RoFormerConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.rotary_value = rotary_value self.rotary_value = rotary_value
self.use_cache = use_cache self.use_cache = use_cache
class RoFormerOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@@ -25,6 +25,7 @@ from ..models.m2m_100 import M2M100OnnxConfig
from ..models.marian import MarianOnnxConfig from ..models.marian import MarianOnnxConfig
from ..models.mbart import MBartOnnxConfig from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig
from ..models.t5 import T5OnnxConfig from ..models.t5 import T5OnnxConfig
from ..models.vit import ViTOnnxConfig from ..models.vit import ViTOnnxConfig
from ..models.xlm_roberta import XLMRobertaOnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig
@@ -333,6 +334,17 @@ class FeaturesManager:
"question-answering", "question-answering",
onnx_config_cls=Data2VecTextOnnxConfig, onnx_config_cls=Data2VecTextOnnxConfig,
), ),
"roformer": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"token-classification",
"multiple-choice",
"question-answering",
"token-classification",
onnx_config_cls=RoFormerOnnxConfig,
),
} }
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))

View File

@@ -179,6 +179,7 @@ PYTORCH_EXPORT_MODELS = {
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),