Add ONNX export for BeiT (#16498)

* Add beit onnx conversion support

* Updated docs

* Added cross reference to ViT ONNX config
This commit is contained in:
Jim Rohrer
2022-04-01 03:52:42 -05:00
committed by GitHub
parent bfeff6cc6a
commit 9de70f213e
5 changed files with 31 additions and 5 deletions

View File

@@ -47,6 +47,7 @@ Ready-made configurations include the following architectures:
- ALBERT - ALBERT
- BART - BART
- BEiT
- BERT - BERT
- Blenderbot - Blenderbot
- BlenderbotSmall - BlenderbotSmall

View File

@@ -22,7 +22,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available, is_visi
_import_structure = { _import_structure = {
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"], "configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig", "BeitOnnxConfig"],
} }
if is_vision_available(): if is_vision_available():
@@ -48,7 +48,7 @@ if is_flax_available():
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig, BeitOnnxConfig
if is_vision_available(): if is_vision_available():
from .feature_extraction_beit import BeitFeatureExtractor from .feature_extraction_beit import BeitFeatureExtractor

View File

@@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" BEiT model configuration""" """ BEiT model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
@@ -176,3 +181,21 @@ class BeitConfig(PretrainedConfig):
self.auxiliary_num_convs = auxiliary_num_convs self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input self.auxiliary_concat_input = auxiliary_concat_input
self.semantic_loss_ignore_index = semantic_loss_ignore_index self.semantic_loss_ignore_index = semantic_loss_ignore_index
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
class BeitOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4

View File

@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from ..models.albert import AlbertOnnxConfig from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig from ..models.bart import BartOnnxConfig
from ..models.beit import BeitOnnxConfig
from ..models.bert import BertOnnxConfig from ..models.bert import BertOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
@@ -270,6 +271,7 @@ class FeaturesManager:
onnx_config_cls=ElectraOnnxConfig, onnx_config_cls=ElectraOnnxConfig,
), ),
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig), "vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"blenderbot": supported_features_mapping( "blenderbot": supported_features_mapping(
"default", "default",
"default-with-past", "default-with-past",

View File

@@ -15,14 +15,13 @@ from transformers.onnx import (
export, export,
validate_model_outputs, validate_model_outputs,
) )
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
if is_torch_available() or is_tf_available(): if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager from transformers.onnx.features import FeaturesManager
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
@require_onnx @require_onnx
class OnnxUtilsTestCaseV2(TestCase): class OnnxUtilsTestCaseV2(TestCase):
@@ -181,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("xlm-roberta", "xlm-roberta-base"), ("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"), ("vit", "google/vit-base-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {