Added XLM onnx config (#17030)
* Add onnx configuration for xlm * Add supported features for xlm * Add xlm to models exportable with onnx * Add xlm architecture to test file * Modify docs * Make code quality fixes
This commit is contained in:
@@ -75,6 +75,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- RoFormer
|
- RoFormer
|
||||||
- T5
|
- T5
|
||||||
- ViT
|
- ViT
|
||||||
|
- XLM
|
||||||
- XLM-RoBERTa
|
- XLM-RoBERTa
|
||||||
- XLM-RoBERTa-XL
|
- XLM-RoBERTa-XL
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_availabl
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"],
|
"configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"],
|
||||||
"tokenization_xlm": ["XLMTokenizer"],
|
"tokenization_xlm": ["XLMTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig
|
||||||
from .tokenization_xlm import XLMTokenizer
|
from .tokenization_xlm import XLMTokenizer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -13,8 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" XLM configuration"""
|
""" XLM configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -228,3 +231,20 @@ class XLMConfig(PretrainedConfig):
|
|||||||
self.n_words = kwargs["n_words"]
|
self.n_words = kwargs["n_words"]
|
||||||
|
|
||||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig
|
||||||
|
class XLMOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
("token_type_ids", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from ..models.roberta import RobertaOnnxConfig
|
|||||||
from ..models.roformer import RoFormerOnnxConfig
|
from ..models.roformer import RoFormerOnnxConfig
|
||||||
from ..models.t5 import T5OnnxConfig
|
from ..models.t5 import T5OnnxConfig
|
||||||
from ..models.vit import ViTOnnxConfig
|
from ..models.vit import ViTOnnxConfig
|
||||||
|
from ..models.xlm import XLMOnnxConfig
|
||||||
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .config import OnnxConfig
|
from .config import OnnxConfig
|
||||||
@@ -357,6 +358,16 @@ class FeaturesManager:
|
|||||||
"vit": supported_features_mapping(
|
"vit": supported_features_mapping(
|
||||||
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
|
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
|
||||||
),
|
),
|
||||||
|
"xlm": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"causal-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=XLMOnnxConfig,
|
||||||
|
),
|
||||||
"xlm-roberta": supported_features_mapping(
|
"xlm-roberta": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -181,6 +181,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
("roformer", "junnyu/roformer_chinese_base"),
|
("roformer", "junnyu/roformer_chinese_base"),
|
||||||
("mobilebert", "google/mobilebert-uncased"),
|
("mobilebert", "google/mobilebert-uncased"),
|
||||||
|
("xlm", "xlm-clm-ende-1024"),
|
||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
|
|||||||
Reference in New Issue
Block a user