Added XLM onnx config (#17030)

* Add onnx configuration for xlm

* Add supported features for xlm

* Add xlm to models exportable with onnx

* Add xlm architecture to test file

* Modify docs

* Make code quality fixes
This commit is contained in:
Ritik Nandwal
2022-05-31 18:56:06 +05:30
committed by GitHub
parent 567d9c061d
commit 5af38953bb
5 changed files with 35 additions and 2 deletions

View File

@@ -75,6 +75,7 @@ Ready-made configurations include the following architectures:
- RoFormer - RoFormer
- T5 - T5
- ViT - ViT
- XLM
- XLM-RoBERTa - XLM-RoBERTa
- XLM-RoBERTa-XL - XLM-RoBERTa-XL

View File

@@ -22,7 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_availabl
_import_structure = { _import_structure = {
"configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"], "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"],
"tokenization_xlm": ["XLMTokenizer"], "tokenization_xlm": ["XLMTokenizer"],
} }
@@ -64,7 +64,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig
from .tokenization_xlm import XLMTokenizer from .tokenization_xlm import XLMTokenizer
try: try:

View File

@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" XLM configuration""" """ XLM configuration"""
from collections import OrderedDict
from typing import Mapping
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
@@ -228,3 +231,20 @@ class XLMConfig(PretrainedConfig):
self.n_words = kwargs["n_words"] self.n_words = kwargs["n_words"]
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs) super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
class XLMOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@@ -30,6 +30,7 @@ from ..models.roberta import RobertaOnnxConfig
from ..models.roformer import RoFormerOnnxConfig from ..models.roformer import RoFormerOnnxConfig
from ..models.t5 import T5OnnxConfig from ..models.t5 import T5OnnxConfig
from ..models.vit import ViTOnnxConfig from ..models.vit import ViTOnnxConfig
from ..models.xlm import XLMOnnxConfig
from ..models.xlm_roberta import XLMRobertaOnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig
from ..utils import logging from ..utils import logging
from .config import OnnxConfig from .config import OnnxConfig
@@ -357,6 +358,16 @@ class FeaturesManager:
"vit": supported_features_mapping( "vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
), ),
"xlm": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=XLMOnnxConfig,
),
"xlm-roberta": supported_features_mapping( "xlm-roberta": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",

View File

@@ -181,6 +181,7 @@ PYTORCH_EXPORT_MODELS = {
("roberta", "roberta-base"), ("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"), ("roformer", "junnyu/roformer_chinese_base"),
("mobilebert", "google/mobilebert-uncased"), ("mobilebert", "google/mobilebert-uncased"),
("xlm", "xlm-clm-ende-1024"),
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),