committed by
GitHub
parent
8afaaa26f5
commit
aaee4038c3
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- OpenAI GPT-2
|
- OpenAI GPT-2
|
||||||
- PLBart
|
- PLBart
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
|
- RoFormer
|
||||||
- T5
|
- T5
|
||||||
- TAPEX
|
- TAPEX
|
||||||
- ViT
|
- ViT
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokeniz
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"],
|
"configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerOnnxConfig"],
|
||||||
"tokenization_roformer": ["RoFormerTokenizer"],
|
"tokenization_roformer": ["RoFormerTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
|
from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerOnnxConfig
|
||||||
from .tokenization_roformer import RoFormerTokenizer
|
from .tokenization_roformer import RoFormerTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -14,7 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" RoFormer model configuration"""
|
""" RoFormer model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -131,3 +135,20 @@ class RoFormerConfig(PretrainedConfig):
|
|||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
self.rotary_value = rotary_value
|
self.rotary_value = rotary_value
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
|
|
||||||
|
|
||||||
|
class RoFormerOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
("token_type_ids", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from ..models.m2m_100 import M2M100OnnxConfig
|
|||||||
from ..models.marian import MarianOnnxConfig
|
from ..models.marian import MarianOnnxConfig
|
||||||
from ..models.mbart import MBartOnnxConfig
|
from ..models.mbart import MBartOnnxConfig
|
||||||
from ..models.roberta import RobertaOnnxConfig
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
|
from ..models.roformer import RoFormerOnnxConfig
|
||||||
from ..models.t5 import T5OnnxConfig
|
from ..models.t5 import T5OnnxConfig
|
||||||
from ..models.vit import ViTOnnxConfig
|
from ..models.vit import ViTOnnxConfig
|
||||||
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
||||||
@@ -333,6 +334,17 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=Data2VecTextOnnxConfig,
|
onnx_config_cls=Data2VecTextOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"roformer": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"causal-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
"multiple-choice",
|
||||||
|
"question-answering",
|
||||||
|
"token-classification",
|
||||||
|
onnx_config_cls=RoFormerOnnxConfig,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
||||||
|
|||||||
@@ -179,6 +179,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
("electra", "google/electra-base-generator"),
|
("electra", "google/electra-base-generator"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
|
("roformer", "junnyu/roformer_chinese_base"),
|
||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
|
|||||||
Reference in New Issue
Block a user