Add ONNX support for DETR (#17904)
This commit is contained in:
@@ -62,6 +62,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- DeBERTa
|
- DeBERTa
|
||||||
- DeBERTa-v2
|
- DeBERTa-v2
|
||||||
- DeiT
|
- DeiT
|
||||||
|
- DETR
|
||||||
- DistilBERT
|
- DistilBERT
|
||||||
- ELECTRA
|
- ELECTRA
|
||||||
- FlauBERT
|
- FlauBERT
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
|
|||||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
|
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"]}
|
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
@@ -47,7 +47,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
|
from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig, DetrOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_vision_available():
|
if not is_vision_available():
|
||||||
|
|||||||
@@ -14,7 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" DETR model configuration"""
|
""" DETR model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -204,3 +210,25 @@ class DetrConfig(PretrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def hidden_size(self) -> int:
|
def hidden_size(self) -> int:
|
||||||
return self.d_model
|
return self.d_model
|
||||||
|
|
||||||
|
|
||||||
|
class DetrOnnxConfig(OnnxConfig):
|
||||||
|
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "sequence"}),
|
||||||
|
("pixel_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-5
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_onnx_opset(self) -> int:
|
||||||
|
return 12
|
||||||
|
|||||||
@@ -77,9 +77,22 @@ class OnnxConfig(ABC):
|
|||||||
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
||||||
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
|
"image-segmentation": OrderedDict(
|
||||||
|
{
|
||||||
|
"logits": {0: "batch", 1: "sequence"},
|
||||||
|
"pred_boxes": {0: "batch", 1: "sequence"},
|
||||||
|
"pred_masks": {0: "batch", 1: "sequence"},
|
||||||
|
}
|
||||||
|
),
|
||||||
"masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
"masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
"masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
"masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
|
||||||
|
"object-detection": OrderedDict(
|
||||||
|
{
|
||||||
|
"logits": {0: "batch", 1: "sequence"},
|
||||||
|
"pred_boxes": {0: "batch", 1: "sequence"},
|
||||||
|
}
|
||||||
|
),
|
||||||
"question-answering": OrderedDict(
|
"question-answering": OrderedDict(
|
||||||
{
|
{
|
||||||
"start_logits": {0: "batch", 1: "sequence"},
|
"start_logits": {0: "batch", 1: "sequence"},
|
||||||
|
|||||||
@@ -15,9 +15,11 @@ if is_torch_available():
|
|||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
|
AutoModelForImageSegmentation,
|
||||||
AutoModelForMaskedImageModeling,
|
AutoModelForMaskedImageModeling,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMultipleChoice,
|
AutoModelForMultipleChoice,
|
||||||
|
AutoModelForObjectDetection,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
@@ -83,8 +85,10 @@ class FeaturesManager:
|
|||||||
"sequence-classification": AutoModelForSequenceClassification,
|
"sequence-classification": AutoModelForSequenceClassification,
|
||||||
"token-classification": AutoModelForTokenClassification,
|
"token-classification": AutoModelForTokenClassification,
|
||||||
"multiple-choice": AutoModelForMultipleChoice,
|
"multiple-choice": AutoModelForMultipleChoice,
|
||||||
|
"object-detection": AutoModelForObjectDetection,
|
||||||
"question-answering": AutoModelForQuestionAnswering,
|
"question-answering": AutoModelForQuestionAnswering,
|
||||||
"image-classification": AutoModelForImageClassification,
|
"image-classification": AutoModelForImageClassification,
|
||||||
|
"image-segmentation": AutoModelForImageSegmentation,
|
||||||
"masked-im": AutoModelForMaskedImageModeling,
|
"masked-im": AutoModelForMaskedImageModeling,
|
||||||
}
|
}
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -227,6 +231,12 @@ class FeaturesManager:
|
|||||||
"deit": supported_features_mapping(
|
"deit": supported_features_mapping(
|
||||||
"default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
|
"default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
|
||||||
),
|
),
|
||||||
|
"detr": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"object-detection",
|
||||||
|
"image-segmentation",
|
||||||
|
onnx_config_cls="models.detr.DetrOnnxConfig",
|
||||||
|
),
|
||||||
"distilbert": supported_features_mapping(
|
"distilbert": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -183,6 +183,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("deberta", "microsoft/deberta-base"),
|
("deberta", "microsoft/deberta-base"),
|
||||||
("deberta-v2", "microsoft/deberta-v2-xlarge"),
|
("deberta-v2", "microsoft/deberta-v2-xlarge"),
|
||||||
("convnext", "facebook/convnext-tiny-224"),
|
("convnext", "facebook/convnext-tiny-224"),
|
||||||
|
("detr", "facebook/detr-resnet-50"),
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
("electra", "google/electra-base-generator"),
|
("electra", "google/electra-base-generator"),
|
||||||
("resnet", "microsoft/resnet-50"),
|
("resnet", "microsoft/resnet-50"),
|
||||||
|
|||||||
Reference in New Issue
Block a user