From 46d0e26a276f18157223ee0474560bcc78d74920 Mon Sep 17 00:00:00 2001 From: Dhruv Karan Date: Tue, 30 Aug 2022 18:00:59 +0530 Subject: [PATCH] Adds OWLViT to models exportable with ONNX (#18588) * onnx conversion for owlvit * .T to .t() * dynamic shapes for pixel values --- docs/source/en/serialization.mdx | 1 + src/transformers/models/owlvit/__init__.py | 2 + .../models/owlvit/configuration_owlvit.py | 50 ++++++++++++++++++- .../models/owlvit/modeling_owlvit.py | 7 ++- src/transformers/onnx/features.py | 4 ++ tests/onnx/test_onnx_v2.py | 1 + 6 files changed, 62 insertions(+), 3 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 89b73df4f5..11336c61a4 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -83,6 +83,7 @@ Ready-made configurations include the following architectures: - MobileViT - MT5 - OpenAI GPT-2 +- OWL-ViT - Perceiver - PLBart - ResNet diff --git a/src/transformers/models/owlvit/__init__.py b/src/transformers/models/owlvit/__init__.py index 8315df69fa..cc528d315e 100644 --- a/src/transformers/models/owlvit/__init__.py +++ b/src/transformers/models/owlvit/__init__.py @@ -32,6 +32,7 @@ _import_structure = { "configuration_owlvit": [ "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OwlViTConfig", + "OwlViTOnnxConfig", "OwlViTTextConfig", "OwlViTVisionConfig", ], @@ -66,6 +67,7 @@ if TYPE_CHECKING: from .configuration_owlvit import ( OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, OwlViTConfig, + OwlViTOnnxConfig, OwlViTTextConfig, OwlViTVisionConfig, ) diff --git a/src/transformers/models/owlvit/configuration_owlvit.py b/src/transformers/models/owlvit/configuration_owlvit.py index 85ffdbadbe..ff0bd6e612 100644 --- a/src/transformers/models/owlvit/configuration_owlvit.py +++ b/src/transformers/models/owlvit/configuration_owlvit.py @@ -16,9 +16,16 @@ import copy import os -from typing import Dict, Union +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union + + +if TYPE_CHECKING: + from ...processing_utils import ProcessorMixin + from ...utils import TensorType from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -334,3 +341,44 @@ class OwlViTConfig(PretrainedConfig): output["vision_config"] = self.vision_config.to_dict() output["model_type"] = self.__class__.model_type return output + + +class OwlViTOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("logits_per_image", {0: "batch"}), + ("logits_per_text", {0: "batch"}), + ("text_embeds", {0: "batch"}), + ("image_embeds", {0: "batch"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + def generate_dummy_inputs( + self, + processor: "ProcessorMixin", + framework: Optional["TensorType"] = None, + ) -> Mapping[str, Any]: + + text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework) + image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework) + return {**text_input_dict, **image_input_dict} + + @property + def default_onnx_opset(self) -> int: + return 14 diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index c0386ab23d..5ff22c901a 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -687,7 +687,10 @@ class OwlViTTextTransformer(nn.Module): last_hidden_state = self.final_layer_norm(last_hidden_state) # take features from the end of tokens embedding (end of token is the highest number in each sequence) - pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)] + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1) + ] if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] @@ -1066,7 +1069,7 @@ class OwlViTModel(OwlViTPreTrainedModel): # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale - logits_per_image = logits_per_text.T + logits_per_image = logits_per_text.t() loss = None if return_loss: diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index eb57df1c96..879ba1c262 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -416,6 +416,10 @@ class FeaturesManager: "seq2seq-lm-with-past", onnx_config_cls="models.m2m_100.M2M100OnnxConfig", ), + "owlvit": supported_features_mapping( + "default", + onnx_config_cls="models.owlvit.OwlViTOnnxConfig", + ), "perceiver": supported_features_mapping( "image-classification", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 52ced984ca..3872f1dfa0 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = { ("layoutlm", "microsoft/layoutlm-base-uncased"), ("layoutlmv3", "microsoft/layoutlmv3-base"), ("levit", "facebook/levit-128S"), + ("owlvit", "google/owlvit-base-patch32"), ("vit", "google/vit-base-patch16-224"), ("deit", "facebook/deit-small-patch16-224"), ("beit", "microsoft/beit-base-patch16-224"),