@@ -56,6 +56,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- ConvBERT
|
- ConvBERT
|
||||||
- Data2VecText
|
- Data2VecText
|
||||||
- Data2VecVision
|
- Data2VecVision
|
||||||
|
- DeiT
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
- ELECTRA
|
- ELECTRA
|
||||||
- FlauBERT
|
- FlauBERT
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_torch_available, is_vision_available
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
|
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -39,7 +39,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
|
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from .feature_extraction_deit import DeiTFeatureExtractor
|
from .feature_extraction_deit import DeiTFeatureExtractor
|
||||||
|
|||||||
@@ -14,7 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" DeiT model configuration"""
|
""" DeiT model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -120,3 +126,19 @@ class DeiTConfig(PretrainedConfig):
|
|||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.qkv_bias = qkv_bias
|
self.qkv_bias = qkv_bias
|
||||||
self.encoder_stride = encoder_stride
|
self.encoder_stride = encoder_stride
|
||||||
|
|
||||||
|
|
||||||
|
class DeiTOnnxConfig(OnnxConfig):
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-4
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
|||||||
from ..models.camembert import CamembertOnnxConfig
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
from ..models.convbert import ConvBertOnnxConfig
|
from ..models.convbert import ConvBertOnnxConfig
|
||||||
from ..models.data2vec import Data2VecTextOnnxConfig
|
from ..models.data2vec import Data2VecTextOnnxConfig
|
||||||
|
from ..models.deit import DeiTOnnxConfig
|
||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.electra import ElectraOnnxConfig
|
from ..models.electra import ElectraOnnxConfig
|
||||||
from ..models.flaubert import FlaubertOnnxConfig
|
from ..models.flaubert import FlaubertOnnxConfig
|
||||||
@@ -38,6 +39,7 @@ if is_torch_available():
|
|||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
|
AutoModelForMaskedImageModeling,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
@@ -103,6 +105,7 @@ class FeaturesManager:
|
|||||||
"multiple-choice": AutoModelForMultipleChoice,
|
"multiple-choice": AutoModelForMultipleChoice,
|
||||||
"question-answering": AutoModelForQuestionAnswering,
|
"question-answering": AutoModelForQuestionAnswering,
|
||||||
"image-classification": AutoModelForImageClassification,
|
"image-classification": AutoModelForImageClassification,
|
||||||
|
"masked-im": AutoModelForMaskedImageModeling,
|
||||||
}
|
}
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_TASKS_TO_TF_AUTOMODELS = {
|
_TASKS_TO_TF_AUTOMODELS = {
|
||||||
@@ -294,8 +297,15 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=ElectraOnnxConfig,
|
onnx_config_cls=ElectraOnnxConfig,
|
||||||
),
|
),
|
||||||
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
|
"vit": supported_features_mapping(
|
||||||
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
|
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
|
||||||
|
),
|
||||||
|
"beit": supported_features_mapping(
|
||||||
|
"default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig
|
||||||
|
),
|
||||||
|
"deit": supported_features_mapping(
|
||||||
|
"default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
|
||||||
|
),
|
||||||
"blenderbot": supported_features_mapping(
|
"blenderbot": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"default-with-past",
|
"default-with-past",
|
||||||
|
|||||||
@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("xlm-roberta", "xlm-roberta-base"),
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
|
("deit", "facebook/deit-small-patch16-224"),
|
||||||
("beit", "microsoft/beit-base-patch16-224"),
|
("beit", "microsoft/beit-base-patch16-224"),
|
||||||
("data2vec-text", "facebook/data2vec-text-base"),
|
("data2vec-text", "facebook/data2vec-text-base"),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user