Adds IBERT to models exportable with ONNX (#14868)
* Add IBertOnnxConfig and tests * add all the supported features for IBERT and remove outputs in IbertOnnxConfig * use OnnxConfig * fix codestyle * remove serialization.rst * codestyle
This commit is contained in:
@@ -40,6 +40,7 @@ Ready-made configurations include the following models:
|
|||||||
- CamemBERT
|
- CamemBERT
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
- GPT Neo
|
- GPT Neo
|
||||||
|
- I-BERT
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
- Longformer
|
- Longformer
|
||||||
- Marian
|
- Marian
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_torch_available
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
|
"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -38,7 +38,7 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
|
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig, IBertOnnxConfig
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_ibert import (
|
from .modeling_ibert import (
|
||||||
|
|||||||
@@ -15,6 +15,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" I-BERT configuration"""
|
""" I-BERT configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from transformers.onnx import OnnxConfig
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -122,3 +126,14 @@ class IBertConfig(PretrainedConfig):
|
|||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.quant_mode = quant_mode
|
self.quant_mode = quant_mode
|
||||||
self.force_dequant = force_dequant
|
self.force_dequant = force_dequant
|
||||||
|
|
||||||
|
|
||||||
|
class IBertOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from ..models.camembert import CamembertOnnxConfig
|
|||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.gpt2 import GPT2OnnxConfig
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
|
from ..models.ibert import IBertOnnxConfig
|
||||||
from ..models.layoutlm import LayoutLMOnnxConfig
|
from ..models.layoutlm import LayoutLMOnnxConfig
|
||||||
from ..models.longformer import LongformerOnnxConfig
|
from ..models.longformer import LongformerOnnxConfig
|
||||||
from ..models.marian import MarianOnnxConfig
|
from ..models.marian import MarianOnnxConfig
|
||||||
@@ -125,6 +126,15 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=BertOnnxConfig,
|
onnx_config_cls=BertOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"ibert": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
# "multiple-choice",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=IBertOnnxConfig,
|
||||||
|
),
|
||||||
"camembert": supported_features_mapping(
|
"camembert": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -171,6 +171,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||||||
PYTORCH_EXPORT_MODELS = {
|
PYTORCH_EXPORT_MODELS = {
|
||||||
("albert", "hf-internal-testing/tiny-albert"),
|
("albert", "hf-internal-testing/tiny-albert"),
|
||||||
("bert", "bert-base-cased"),
|
("bert", "bert-base-cased"),
|
||||||
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
# ("longFormer", "longformer-base-4096"),
|
# ("longFormer", "longformer-base-4096"),
|
||||||
|
|||||||
Reference in New Issue
Block a user