diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 790b23095a..f5c113e310 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -62,6 +62,7 @@ Ready-made configurations include the following architectures: - DeBERTa - DeBERTa-v2 - DeiT +- DETR - DistilBERT - ELECTRA - FlauBERT diff --git a/src/transformers/models/detr/__init__.py b/src/transformers/models/detr/__init__.py index 5958418807..b9b6d30c32 100644 --- a/src/transformers/models/detr/__init__.py +++ b/src/transformers/models/detr/__init__.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available -_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"]} +_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]} try: if not is_vision_available(): @@ -47,7 +47,7 @@ else: if TYPE_CHECKING: - from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig + from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig, DetrOnnxConfig try: if not is_vision_available(): diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index 73f797a7d6..fa8086efc4 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -14,7 +14,13 @@ # limitations under the License. """ DETR model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -204,3 +210,25 @@ class DetrConfig(PretrainedConfig): @property def hidden_size(self) -> int: return self.d_model + + +class DetrOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ("pixel_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 + + @property + def default_onnx_opset(self) -> int: + return 12 diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index f97d61ea40..6097ebf49a 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -77,9 +77,22 @@ class OnnxConfig(ABC): "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), + "image-segmentation": OrderedDict( + { + "logits": {0: "batch", 1: "sequence"}, + "pred_boxes": {0: "batch", 1: "sequence"}, + "pred_masks": {0: "batch", 1: "sequence"}, + } + ), "masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "multiple-choice": OrderedDict({"logits": {0: "batch"}}), + "object-detection": OrderedDict( + { + "logits": {0: "batch", 1: "sequence"}, + "pred_boxes": {0: "batch", 1: "sequence"}, + } + ), "question-answering": OrderedDict( { "start_logits": {0: "batch", 1: "sequence"}, diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 1d78a0ecd4..9a76cfc012 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -15,9 +15,11 @@ if is_torch_available(): AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, + AutoModelForImageSegmentation, AutoModelForMaskedImageModeling, AutoModelForMaskedLM, AutoModelForMultipleChoice, + AutoModelForObjectDetection, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, @@ -83,8 +85,10 @@ class FeaturesManager: "sequence-classification": AutoModelForSequenceClassification, "token-classification": AutoModelForTokenClassification, "multiple-choice": AutoModelForMultipleChoice, + "object-detection": AutoModelForObjectDetection, "question-answering": AutoModelForQuestionAnswering, "image-classification": AutoModelForImageClassification, + "image-segmentation": AutoModelForImageSegmentation, "masked-im": AutoModelForMaskedImageModeling, } if is_tf_available(): @@ -227,6 +231,12 @@ class FeaturesManager: "deit": supported_features_mapping( "default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig" ), + "detr": supported_features_mapping( + "default", + "object-detection", + "image-segmentation", + onnx_config_cls="models.detr.DetrOnnxConfig", + ), "distilbert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 6adc0731b4..f409b36f91 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -183,6 +183,7 @@ PYTORCH_EXPORT_MODELS = { ("deberta", "microsoft/deberta-base"), ("deberta-v2", "microsoft/deberta-v2-xlarge"), ("convnext", "facebook/convnext-tiny-224"), + ("detr", "facebook/detr-resnet-50"), ("distilbert", "distilbert-base-cased"), ("electra", "google/electra-base-generator"), ("resnet", "microsoft/resnet-50"),