diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx
index 11336c61a4..d6bf15df7f 100644
--- a/docs/source/en/serialization.mdx
+++ b/docs/source/en/serialization.mdx
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
 - FlauBERT
 - GPT Neo
 - GPT-J
+- GroupViT
 - I-BERT
 - LayoutLM
 - LayoutLMv3
diff --git a/src/transformers/models/groupvit/__init__.py b/src/transformers/models/groupvit/__init__.py
index 8d90205497..3985e9ecff 100644
--- a/src/transformers/models/groupvit/__init__.py
+++ b/src/transformers/models/groupvit/__init__.py
@@ -24,6 +24,7 @@ _import_structure = {
     "configuration_groupvit": [
         "GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "GroupViTConfig",
+        "GroupViTOnnxConfig",
         "GroupViTTextConfig",
         "GroupViTVisionConfig",
     ],
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
     from .configuration_groupvit import (
         GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         GroupViTConfig,
+        GroupViTOnnxConfig,
         GroupViTTextConfig,
         GroupViTVisionConfig,
     )
diff --git a/src/transformers/models/groupvit/configuration_groupvit.py b/src/transformers/models/groupvit/configuration_groupvit.py
index 8940cf40b9..895c0608b7 100644
--- a/src/transformers/models/groupvit/configuration_groupvit.py
+++ b/src/transformers/models/groupvit/configuration_groupvit.py
@@ -16,12 +16,19 @@
 
 import copy
 import os
-from typing import Union
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
 
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
 
 
+if TYPE_CHECKING:
+    from ...processing_utils import ProcessorMixin
+    from ...utils import TensorType
+
+
 logger = logging.get_logger(__name__)
 
 GROUPVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -343,3 +350,44 @@ class GroupViTConfig(PretrainedConfig):
         output["vision_config"] = self.vision_config.to_dict()
         output["model_type"] = self.__class__.model_type
         return output
+
+
+class GroupViTOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("logits_per_image", {0: "batch"}),
+                ("logits_per_text", {0: "batch"}),
+                ("text_embeds", {0: "batch"}),
+                ("image_embeds", {0: "batch"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+    def generate_dummy_inputs(
+        self,
+        processor: "ProcessorMixin",
+        framework: Optional["TensorType"] = None,
+    ) -> Mapping[str, Any]:
+
+        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
+        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        return {**text_input_dict, **image_input_dict}
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 14
diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py
index 9817065ab3..3d2f78e3cf 100644
--- a/src/transformers/models/groupvit/modeling_groupvit.py
+++ b/src/transformers/models/groupvit/modeling_groupvit.py
@@ -1542,7 +1542,7 @@ class GroupViTModel(GroupViTPreTrainedModel):
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.T
+        logits_per_image = logits_per_text.t()
 
         seg_logits = None
         if output_segmentation:
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 879ba1c262..b1ea30c1af 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -326,6 +326,10 @@ class FeaturesManager:
             "sequence-classification",
             onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
         ),
+        "groupvit": supported_features_mapping(
+            "default",
+            onnx_config_cls="models.groupvit.GroupViTOnnxConfig",
+        ),
         "ibert": supported_features_mapping(
             "default",
             "masked-lm",
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 3872f1dfa0..16ee78a63c 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -204,6 +204,7 @@ PYTORCH_EXPORT_MODELS = {
     ("xlm-roberta", "xlm-roberta-base"),
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("layoutlmv3", "microsoft/layoutlmv3-base"),
+    ("groupvit", "nvidia/groupvit-gcc-yfcc"),
     ("levit", "facebook/levit-128S"),
     ("owlvit", "google/owlvit-base-patch32"),
     ("vit", "google/vit-base-patch16-224"),