Add ONNX export for BeiT (#16498)
* Add beit onnx conversion support * Updated docs * Added cross reference to ViT ONNX config
This commit is contained in:
@@ -47,6 +47,7 @@ Ready-made configurations include the following architectures:
|
|||||||
|
|
||||||
- ALBERT
|
- ALBERT
|
||||||
- BART
|
- BART
|
||||||
|
- BEiT
|
||||||
- BERT
|
- BERT
|
||||||
- Blenderbot
|
- Blenderbot
|
||||||
- BlenderbotSmall
|
- BlenderbotSmall
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available, is_visi
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"],
|
"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig", "BeitOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -48,7 +48,7 @@ if is_flax_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig
|
from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig, BeitOnnxConfig
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from .feature_extraction_beit import BeitFeatureExtractor
|
from .feature_extraction_beit import BeitFeatureExtractor
|
||||||
|
|||||||
@@ -13,8 +13,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" BEiT model configuration"""
|
""" BEiT model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -176,3 +181,21 @@ class BeitConfig(PretrainedConfig):
|
|||||||
self.auxiliary_num_convs = auxiliary_num_convs
|
self.auxiliary_num_convs = auxiliary_num_convs
|
||||||
self.auxiliary_concat_input = auxiliary_concat_input
|
self.auxiliary_concat_input = auxiliary_concat_input
|
||||||
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
|
||||||
|
class BeitOnnxConfig(OnnxConfig):
|
||||||
|
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-4
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple, Type, Union
|
|||||||
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
|
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
|
||||||
from ..models.albert import AlbertOnnxConfig
|
from ..models.albert import AlbertOnnxConfig
|
||||||
from ..models.bart import BartOnnxConfig
|
from ..models.bart import BartOnnxConfig
|
||||||
|
from ..models.beit import BeitOnnxConfig
|
||||||
from ..models.bert import BertOnnxConfig
|
from ..models.bert import BertOnnxConfig
|
||||||
from ..models.blenderbot import BlenderbotOnnxConfig
|
from ..models.blenderbot import BlenderbotOnnxConfig
|
||||||
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
||||||
@@ -270,6 +271,7 @@ class FeaturesManager:
|
|||||||
onnx_config_cls=ElectraOnnxConfig,
|
onnx_config_cls=ElectraOnnxConfig,
|
||||||
),
|
),
|
||||||
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
|
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
|
||||||
|
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
|
||||||
"blenderbot": supported_features_mapping(
|
"blenderbot": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"default-with-past",
|
"default-with-past",
|
||||||
|
|||||||
@@ -15,14 +15,13 @@ from transformers.onnx import (
|
|||||||
export,
|
export,
|
||||||
validate_model_outputs,
|
validate_model_outputs,
|
||||||
)
|
)
|
||||||
|
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||||
|
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available() or is_tf_available():
|
if is_torch_available() or is_tf_available():
|
||||||
from transformers.onnx.features import FeaturesManager
|
from transformers.onnx.features import FeaturesManager
|
||||||
|
|
||||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
|
||||||
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
|
|
||||||
|
|
||||||
|
|
||||||
@require_onnx
|
@require_onnx
|
||||||
class OnnxUtilsTestCaseV2(TestCase):
|
class OnnxUtilsTestCaseV2(TestCase):
|
||||||
@@ -181,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
|
("beit", "microsoft/beit-base-patch16-224"),
|
||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||||
|
|||||||
Reference in New Issue
Block a user