added deit onnx config (#16887)

* added deit onnx config
This commit is contained in:
Rushi Chaudhari
2022-04-25 14:50:45 -04:00
committed by GitHub
parent 9331b37967
commit 8246caf3eb
5 changed files with 38 additions and 4 deletions

View File

@@ -56,6 +56,7 @@ Ready-made configurations include the following architectures:
- ConvBERT
- Data2VecText
- Data2VecVision
- DeiT
- DistilBERT
- ELECTRA
- FlauBERT

View File

@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"],
}
if is_vision_available():
@@ -39,7 +39,7 @@ if is_torch_available():
if TYPE_CHECKING:
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig
if is_vision_available():
from .feature_extraction_deit import DeiTFeatureExtractor

View File

@@ -14,7 +14,13 @@
# limitations under the License.
""" DeiT model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@@ -120,3 +126,19 @@ class DeiTConfig(PretrainedConfig):
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.encoder_stride = encoder_stride
class DeiTOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-4

View File

@@ -12,6 +12,7 @@ from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.convbert import ConvBertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.deit import DeiTOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
from ..models.flaubert import FlaubertOnnxConfig
@@ -38,6 +39,7 @@ if is_torch_available():
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedImageModeling,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
@@ -103,6 +105,7 @@ class FeaturesManager:
"multiple-choice": AutoModelForMultipleChoice,
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
"masked-im": AutoModelForMaskedImageModeling,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
@@ -294,8 +297,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=ElectraOnnxConfig,
),
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
),
"beit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig
),
"deit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
),
"blenderbot": supported_features_mapping(
"default",
"default-with-past",

View File

@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
("deit", "facebook/deit-small-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
}