Add ONNX support for ConvNeXT (#17627)
This commit is contained in:
@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- BlenderbotSmall
|
- BlenderbotSmall
|
||||||
- CamemBERT
|
- CamemBERT
|
||||||
- ConvBERT
|
- ConvBERT
|
||||||
|
- ConvNeXT
|
||||||
- Data2VecText
|
- Data2VecText
|
||||||
- Data2VecVision
|
- Data2VecVision
|
||||||
- DeiT
|
- DeiT
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ from ...utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"]}
|
_import_structure = {
|
||||||
|
"configuration_convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig", "ConvNextOnnxConfig"]
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
@@ -63,7 +65,7 @@ else:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
|
from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig, ConvNextOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
|
|||||||
@@ -14,7 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" ConvNeXT model configuration"""
|
""" ConvNeXT model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -101,3 +107,20 @@ class ConvNextConfig(PretrainedConfig):
|
|||||||
self.layer_scale_init_value = layer_scale_init_value
|
self.layer_scale_init_value = layer_scale_init_value
|
||||||
self.drop_path_rate = drop_path_rate
|
self.drop_path_rate = drop_path_rate
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNextOnnxConfig(OnnxConfig):
|
||||||
|
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-5
|
||||||
|
|||||||
@@ -193,6 +193,11 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls="models.convbert.ConvBertOnnxConfig",
|
onnx_config_cls="models.convbert.ConvBertOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"convnext": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"image-classification",
|
||||||
|
onnx_config_cls="models.convnext.ConvNextOnnxConfig",
|
||||||
|
),
|
||||||
"data2vec-text": supported_features_mapping(
|
"data2vec-text": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("ibert", "kssteven/ibert-roberta-base"),
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
("convbert", "YituTech/conv-bert-base"),
|
("convbert", "YituTech/conv-bert-base"),
|
||||||
|
("convnext", "facebook/convnext-tiny-224"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
("electra", "google/electra-base-generator"),
|
("electra", "google/electra-base-generator"),
|
||||||
("resnet", "microsoft/resnet-50"),
|
("resnet", "microsoft/resnet-50"),
|
||||||
|
|||||||
Reference in New Issue
Block a user