diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 87d3273228..4ae5c9a57e 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -70,6 +70,7 @@ Ready-made configurations include the following architectures: - OpenAI GPT-2 - PLBart - RoBERTa +- RoFormer - T5 - TAPEX - ViT diff --git a/src/transformers/models/roformer/__init__.py b/src/transformers/models/roformer/__init__.py index 45806f082f..ec99c5a3b8 100644 --- a/src/transformers/models/roformer/__init__.py +++ b/src/transformers/models/roformer/__init__.py @@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz _import_structure = { - "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"], + "configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerOnnxConfig"], "tokenization_roformer": ["RoFormerTokenizer"], } @@ -73,7 +73,7 @@ if is_flax_available(): if TYPE_CHECKING: - from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig + from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerOnnxConfig from .tokenization_roformer import RoFormerTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index bb7961f38e..2c5de2bbbe 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -14,7 +14,11 @@ # limitations under the License. """ RoFormer model configuration""" +from collections import OrderedDict +from typing import Mapping + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -131,3 +135,20 @@ class RoFormerConfig(PretrainedConfig): self.layer_norm_eps = layer_norm_eps self.rotary_value = rotary_value self.use_cache = use_cache + + +class RoFormerOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 4133d6918c..a4d3a49388 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -25,6 +25,7 @@ from ..models.m2m_100 import M2M100OnnxConfig from ..models.marian import MarianOnnxConfig from ..models.mbart import MBartOnnxConfig from ..models.roberta import RobertaOnnxConfig +from ..models.roformer import RoFormerOnnxConfig from ..models.t5 import T5OnnxConfig from ..models.vit import ViTOnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig @@ -333,6 +334,17 @@ class FeaturesManager: "question-answering", onnx_config_cls=Data2VecTextOnnxConfig, ), + "roformer": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + "token-classification", + "multiple-choice", + "question-answering", + "token-classification", + onnx_config_cls=RoFormerOnnxConfig, + ), } AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 40a9645509..ea5a547639 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -179,6 +179,7 @@ PYTORCH_EXPORT_MODELS = { ("distilbert", "distilbert-base-cased"), ("electra", "google/electra-base-generator"), ("roberta", "roberta-base"), + ("roformer", "junnyu/roformer_chinese_base"), ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("vit", "google/vit-base-patch16-224"),