Add Onnx Config for PoolFormer (#20868)
poolformer onnx Co-authored-by: syed <syed.abdul@sandlogic.com>
This commit is contained in:
committed by
GitHub
parent
4a4cd6cd02
commit
15bc776fec
@@ -102,6 +102,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- OWL-ViT
|
- OWL-ViT
|
||||||
- Perceiver
|
- Perceiver
|
||||||
- PLBart
|
- PLBart
|
||||||
|
- PoolFormer
|
||||||
- RemBERT
|
- RemBERT
|
||||||
- ResNet
|
- ResNet
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
|
|||||||
@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING
|
|||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"]}
|
_import_structure = {
|
||||||
|
"configuration_poolformer": [
|
||||||
|
"POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
|
"PoolFormerConfig",
|
||||||
|
"PoolFormerOnnxConfig",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
@@ -47,7 +53,11 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig
|
from .configuration_poolformer import (
|
||||||
|
POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
PoolFormerConfig,
|
||||||
|
PoolFormerOnnxConfig,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
|
|||||||
@@ -13,8 +13,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PoolFormer model configuration"""
|
""" PoolFormer model configuration"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -125,3 +130,20 @@ class PoolFormerConfig(PretrainedConfig):
|
|||||||
self.layer_scale_init_value = layer_scale_init_value
|
self.layer_scale_init_value = layer_scale_init_value
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolFormerOnnxConfig(OnnxConfig):
|
||||||
|
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 2e-3
|
||||||
|
|||||||
@@ -447,6 +447,9 @@ class FeaturesManager:
|
|||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
|
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"poolformer": supported_features_mapping(
|
||||||
|
"default", "image-classification", onnx_config_cls="models.poolformer.PoolFormerOnnxConfig"
|
||||||
|
),
|
||||||
"rembert": supported_features_mapping(
|
"rembert": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("owlvit", "google/owlvit-base-patch32"),
|
("owlvit", "google/owlvit-base-patch32"),
|
||||||
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
|
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
|
||||||
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
|
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
|
||||||
|
("poolformer", "sail/poolformer_s12"),
|
||||||
("rembert", "google/rembert"),
|
("rembert", "google/rembert"),
|
||||||
("resnet", "microsoft/resnet-50"),
|
("resnet", "microsoft/resnet-50"),
|
||||||
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
|
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
|
||||||
|
|||||||
Reference in New Issue
Block a user