Add OnnxConfig for SqueezeBert iss17314 (#17315)
* add onnx config for SqueezeBert * add test for onnx config for SqueezeBert * add automatically updated doc for onnx config for SqueezeBert * Update src/transformers/onnx/features.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update src/transformers/models/squeezebert/configuration_squeezebert.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -73,6 +73,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- PLBart
|
- PLBart
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
- RoFormer
|
- RoFormer
|
||||||
|
- SqueezeBERT
|
||||||
- T5
|
- T5
|
||||||
- ViT
|
- ViT
|
||||||
- XLM
|
- XLM
|
||||||
|
|||||||
@@ -22,7 +22,11 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"],
|
"configuration_squeezebert": [
|
||||||
|
"SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
|
"SqueezeBertConfig",
|
||||||
|
"SqueezeBertOnnxConfig",
|
||||||
|
],
|
||||||
"tokenization_squeezebert": ["SqueezeBertTokenizer"],
|
"tokenization_squeezebert": ["SqueezeBertTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +58,11 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
|
from .configuration_squeezebert import (
|
||||||
|
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
SqueezeBertConfig,
|
||||||
|
SqueezeBertOnnxConfig,
|
||||||
|
)
|
||||||
from .tokenization_squeezebert import SqueezeBertTokenizer
|
from .tokenization_squeezebert import SqueezeBertTokenizer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -13,8 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" SqueezeBERT model configuration"""
|
""" SqueezeBERT model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -154,3 +157,20 @@ class SqueezeBertConfig(PretrainedConfig):
|
|||||||
self.post_attention_groups = post_attention_groups
|
self.post_attention_groups = post_attention_groups
|
||||||
self.intermediate_groups = intermediate_groups
|
self.intermediate_groups = intermediate_groups
|
||||||
self.output_groups = output_groups
|
self.output_groups = output_groups
|
||||||
|
|
||||||
|
|
||||||
|
# # Copied from transformers.models.bert.configuration_bert.BertOnxxConfig with Bert->SqueezeBert
|
||||||
|
class SqueezeBertOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
("token_type_ids", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from ..models.mbart import MBartOnnxConfig
|
|||||||
from ..models.mobilebert import MobileBertOnnxConfig
|
from ..models.mobilebert import MobileBertOnnxConfig
|
||||||
from ..models.roberta import RobertaOnnxConfig
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
from ..models.roformer import RoFormerOnnxConfig
|
from ..models.roformer import RoFormerOnnxConfig
|
||||||
|
from ..models.squeezebert import SqueezeBertOnnxConfig
|
||||||
from ..models.t5 import T5OnnxConfig
|
from ..models.t5 import T5OnnxConfig
|
||||||
from ..models.vit import ViTOnnxConfig
|
from ..models.vit import ViTOnnxConfig
|
||||||
from ..models.xlm import XLMOnnxConfig
|
from ..models.xlm import XLMOnnxConfig
|
||||||
@@ -352,6 +353,15 @@ class FeaturesManager:
|
|||||||
"token-classification",
|
"token-classification",
|
||||||
onnx_config_cls=RoFormerOnnxConfig,
|
onnx_config_cls=RoFormerOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"squeezebert": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=SqueezeBertOnnxConfig,
|
||||||
|
),
|
||||||
"t5": supported_features_mapping(
|
"t5": supported_features_mapping(
|
||||||
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
|
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("electra", "google/electra-base-generator"),
|
("electra", "google/electra-base-generator"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
("roformer", "junnyu/roformer_chinese_base"),
|
("roformer", "junnyu/roformer_chinese_base"),
|
||||||
|
("squeezebert", "squeezebert/squeezebert-uncased"),
|
||||||
("mobilebert", "google/mobilebert-uncased"),
|
("mobilebert", "google/mobilebert-uncased"),
|
||||||
("xlm", "xlm-clm-ende-1024"),
|
("xlm", "xlm-clm-ende-1024"),
|
||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
|
|||||||
Reference in New Issue
Block a user