add ONNX support for LeVit (#18154)
Co-authored-by: Guilhem Chéron <guilhemc@authentifier.com>
This commit is contained in:
@@ -72,6 +72,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- I-BERT
|
- I-BERT
|
||||||
- LayoutLM
|
- LayoutLM
|
||||||
- LayoutLMv3
|
- LayoutLMv3
|
||||||
|
- LeViT
|
||||||
- LongT5
|
- LongT5
|
||||||
- M2M100
|
- M2M100
|
||||||
- Marian
|
- Marian
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
|
|||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig"]}
|
_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig", "LevitOnnxConfig"]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
@@ -46,7 +46,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig
|
from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig, LevitOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
|
|||||||
@@ -14,7 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" LeViT model configuration"""
|
""" LeViT model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -120,3 +126,21 @@ class LevitConfig(PretrainedConfig):
|
|||||||
["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
|
["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
|
||||||
["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
|
["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
|
||||||
|
class LevitOnnxConfig(OnnxConfig):
|
||||||
|
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-4
|
||||||
|
|||||||
@@ -333,6 +333,9 @@ class FeaturesManager:
|
|||||||
"token-classification",
|
"token-classification",
|
||||||
onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
|
onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
|
||||||
),
|
),
|
||||||
|
"levit": supported_features_mapping(
|
||||||
|
"default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig"
|
||||||
|
),
|
||||||
"longt5": supported_features_mapping(
|
"longt5": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"default-with-past",
|
"default-with-past",
|
||||||
|
|||||||
@@ -196,6 +196,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("layoutlmv3", "microsoft/layoutlmv3-base"),
|
("layoutlmv3", "microsoft/layoutlmv3-base"),
|
||||||
|
("levit", "facebook/levit-128S"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
("deit", "facebook/deit-small-patch16-224"),
|
("deit", "facebook/deit-small-patch16-224"),
|
||||||
("beit", "microsoft/beit-base-patch16-224"),
|
("beit", "microsoft/beit-base-patch16-224"),
|
||||||
|
|||||||
Reference in New Issue
Block a user