Adds OWLViT to models exportable with ONNX (#18588)
* Add ONNX conversion support for OWL-ViT
* Replace `.T` with `.t()`
* Use dynamic shapes for pixel values
This commit is contained in:
@@ -83,6 +83,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- MobileViT
|
- MobileViT
|
||||||
- MT5
|
- MT5
|
||||||
- OpenAI GPT-2
|
- OpenAI GPT-2
|
||||||
|
- OWL-ViT
|
||||||
- Perceiver
|
- Perceiver
|
||||||
- PLBart
|
- PLBart
|
||||||
- ResNet
|
- ResNet
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ _import_structure = {
|
|||||||
"configuration_owlvit": [
|
"configuration_owlvit": [
|
||||||
"OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
"OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
"OwlViTConfig",
|
"OwlViTConfig",
|
||||||
|
"OwlViTOnnxConfig",
|
||||||
"OwlViTTextConfig",
|
"OwlViTTextConfig",
|
||||||
"OwlViTVisionConfig",
|
"OwlViTVisionConfig",
|
||||||
],
|
],
|
||||||
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
|
|||||||
from .configuration_owlvit import (
|
from .configuration_owlvit import (
|
||||||
OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
OwlViTConfig,
|
OwlViTConfig,
|
||||||
|
OwlViTOnnxConfig,
|
||||||
OwlViTTextConfig,
|
OwlViTTextConfig,
|
||||||
OwlViTVisionConfig,
|
OwlViTVisionConfig,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,9 +16,16 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Union
|
from collections import OrderedDict
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...processing_utils import ProcessorMixin
|
||||||
|
from ...utils import TensorType
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -334,3 +341,44 @@ class OwlViTConfig(PretrainedConfig):
|
|||||||
output["vision_config"] = self.vision_config.to_dict()
|
output["vision_config"] = self.vision_config.to_dict()
|
||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OwlViTOnnxConfig(OnnxConfig):
    """ONNX export configuration for OWL-ViT.

    Declares the named axes of the model inputs and outputs, the absolute
    tolerance used for validation, how dummy inputs are generated, and the
    default ONNX opset.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the ordered mapping of input names to their named axes."""
        axes = OrderedDict()
        axes["input_ids"] = {0: "batch", 1: "sequence"}
        axes["pixel_values"] = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        axes["attention_mask"] = {0: "batch", 1: "sequence"}
        return axes

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the ordered mapping of output names to their named axes."""
        output_names = ("logits_per_image", "logits_per_text", "text_embeds", "image_embeds")
        # A fresh axis dict is built per output so the entries are not aliased.
        return OrderedDict((name, {0: "batch"}) for name in output_names)

    @property
    def atol_for_validation(self) -> float:
        """Return the absolute tolerance used for validation."""
        return 1e-4

    def generate_dummy_inputs(
        self,
        processor: "ProcessorMixin",
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs from the processor's tokenizer and feature extractor.

        The parent implementation is called once with ``processor.tokenizer``
        and once with ``processor.feature_extractor``; the two results are
        merged into one dict.

        Args:
            processor: Object exposing ``tokenizer`` and ``feature_extractor``.
            framework: Forwarded unchanged to the parent implementation.

        Returns:
            A dict holding the text entries followed by the image entries; on
            a key collision the image entry wins.
        """
        parent = super()
        merged = dict(parent.generate_dummy_inputs(processor.tokenizer, framework=framework))
        merged.update(parent.generate_dummy_inputs(processor.feature_extractor, framework=framework))
        return merged

    @property
    def default_onnx_opset(self) -> int:
        """Return the default ONNX opset version for this model."""
        return 14
|
||||||
|
|||||||
@@ -687,7 +687,10 @@ class OwlViTTextTransformer(nn.Module):
|
|||||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||||
|
|
||||||
# take features from the end of tokens embedding (end of token is the highest number in each sequence)
|
# take features from the end of tokens embedding (end of token is the highest number in each sequence)
|
||||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
|
pooled_output = last_hidden_state[
|
||||||
|
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
|
||||||
|
]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||||
@@ -1066,7 +1069,7 @@ class OwlViTModel(OwlViTPreTrainedModel):
|
|||||||
# cosine similarity as logits
|
# cosine similarity as logits
|
||||||
logit_scale = self.logit_scale.exp()
|
logit_scale = self.logit_scale.exp()
|
||||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||||
logits_per_image = logits_per_text.T
|
logits_per_image = logits_per_text.t()
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if return_loss:
|
if return_loss:
|
||||||
|
|||||||
@@ -416,6 +416,10 @@ class FeaturesManager:
|
|||||||
"seq2seq-lm-with-past",
|
"seq2seq-lm-with-past",
|
||||||
onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
|
onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
|
||||||
),
|
),
|
||||||
|
"owlvit": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
onnx_config_cls="models.owlvit.OwlViTOnnxConfig",
|
||||||
|
),
|
||||||
"perceiver": supported_features_mapping(
|
"perceiver": supported_features_mapping(
|
||||||
"image-classification",
|
"image-classification",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
|
|||||||
@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("layoutlmv3", "microsoft/layoutlmv3-base"),
|
("layoutlmv3", "microsoft/layoutlmv3-base"),
|
||||||
("levit", "facebook/levit-128S"),
|
("levit", "facebook/levit-128S"),
|
||||||
|
("owlvit", "google/owlvit-base-patch32"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
("deit", "facebook/deit-small-patch16-224"),
|
("deit", "facebook/deit-small-patch16-224"),
|
||||||
("beit", "microsoft/beit-base-patch16-224"),
|
("beit", "microsoft/beit-base-patch16-224"),
|
||||||
|
|||||||
Reference in New Issue
Block a user