Adds CLIP to models exportable with ONNX (#18515)
* onnx config for clip
* default opset as 14
* changes from the original repo
* input values order fix
* outputs fix
* remove unused import
* ran make fix-copies
* black format
* review comments: forward ref, import fix, model change revert, .to cleanup
* make style
* formatting fixes
* revert groupvit
* comment for cast to int32
* comment fix
* make .T as .t() for onnx conversion
* ran make fix-copies
* remove unneeded comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix copies
* remove comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- BlenderbotSmall
|
- BlenderbotSmall
|
||||||
- BLOOM
|
- BLOOM
|
||||||
- CamemBERT
|
- CamemBERT
|
||||||
|
- CLIP
|
||||||
- CodeGen
|
- CodeGen
|
||||||
- ConvBERT
|
- ConvBERT
|
||||||
- ConvNeXT
|
- ConvNeXT
|
||||||
|
|||||||
@@ -29,7 +29,13 @@ from ...utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_clip": ["CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP", "CLIPConfig", "CLIPTextConfig", "CLIPVisionConfig"],
|
"configuration_clip": [
|
||||||
|
"CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
|
"CLIPConfig",
|
||||||
|
"CLIPOnnxConfig",
|
||||||
|
"CLIPTextConfig",
|
||||||
|
"CLIPVisionConfig",
|
||||||
|
],
|
||||||
"tokenization_clip": ["CLIPTokenizer"],
|
"tokenization_clip": ["CLIPTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +101,13 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
from .configuration_clip import (
|
||||||
|
CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
CLIPConfig,
|
||||||
|
CLIPOnnxConfig,
|
||||||
|
CLIPTextConfig,
|
||||||
|
CLIPVisionConfig,
|
||||||
|
)
|
||||||
from .tokenization_clip import CLIPTokenizer
|
from .tokenization_clip import CLIPTokenizer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -16,9 +16,16 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from typing import Union
|
from collections import OrderedDict
|
||||||
|
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...processing_utils import ProcessorMixin
|
||||||
|
from ...utils import TensorType
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -317,3 +324,44 @@ class CLIPConfig(PretrainedConfig):
|
|||||||
output["vision_config"] = self.vision_config.to_dict()
|
output["vision_config"] = self.vision_config.to_dict()
|
||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPOnnxConfig(OnnxConfig):
    """ONNX export configuration for CLIP, declaring its text and image inputs and its outputs."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map each input name to its named axes (axis index -> axis name), in declaration order."""
        axes = OrderedDict()
        axes["input_ids"] = {0: "batch", 1: "sequence"}
        axes["pixel_values"] = {0: "batch"}
        axes["attention_mask"] = {0: "batch", 1: "sequence"}
        return axes

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map each output name to its named axes; every output only names axis 0 as ``batch``."""
        output_names = ("logits_per_image", "logits_per_text", "text_embeds", "image_embeds")
        # Each output gets its own axis dict so the entries never share a mutable object.
        return OrderedDict((name, {0: "batch"}) for name in output_names)

    @property
    def atol_for_validation(self) -> float:
        """Return the absolute tolerance value ``1e-4``."""
        return 1e-4

    def generate_dummy_inputs(
        self,
        processor: "ProcessorMixin",
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """
        Build dummy inputs by combining the parent class's dummy inputs for the processor's tokenizer with those for
        its feature extractor.

        Args:
            processor: Object exposing a `tokenizer` and a `feature_extractor` attribute.
            framework: Forwarded unchanged to the parent implementation for both calls.

        Returns:
            A single dict holding the text entries followed by the image entries; if both share a key, the image
            value is the one kept.
        """
        text_inputs = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
        image_inputs = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)

        combined = {}
        combined.update(text_inputs)
        # Applied second so that image entries take precedence on any shared key.
        combined.update(image_inputs)
        return combined

    @property
    def default_onnx_opset(self) -> int:
        """Return the default ONNX opset version, 14."""
        return 14
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||||
caption_loss = contrastive_loss(similarity)
|
caption_loss = contrastive_loss(similarity)
|
||||||
image_loss = contrastive_loss(similarity.T)
|
image_loss = contrastive_loss(similarity.t())
|
||||||
return (caption_loss + image_loss) / 2.0
|
return (caption_loss + image_loss) / 2.0
|
||||||
|
|
||||||
|
|
||||||
@@ -660,7 +660,10 @@ class CLIPTextTransformer(nn.Module):
|
|||||||
|
|
||||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
|
pooled_output = last_hidden_state[
|
||||||
|
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
|
||||||
|
]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||||
@@ -1050,7 +1053,7 @@ class CLIPModel(CLIPPreTrainedModel):
|
|||||||
# cosine similarity as logits
|
# cosine similarity as logits
|
||||||
logit_scale = self.logit_scale.exp()
|
logit_scale = self.logit_scale.exp()
|
||||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||||
logits_per_image = logits_per_text.T
|
logits_per_image = logits_per_text.t()
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if return_loss:
|
if return_loss:
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
|||||||
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
|
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
|
||||||
def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
|
def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||||
caption_loss = contrastive_loss(similarity)
|
caption_loss = contrastive_loss(similarity)
|
||||||
image_loss = contrastive_loss(similarity.T)
|
image_loss = contrastive_loss(similarity.t())
|
||||||
return (caption_loss + image_loss) / 2.0
|
return (caption_loss + image_loss) / 2.0
|
||||||
|
|
||||||
|
|
||||||
@@ -1132,7 +1132,10 @@ class GroupViTTextTransformer(nn.Module):
|
|||||||
|
|
||||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
|
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||||
|
pooled_output = last_hidden_state[
|
||||||
|
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
|
||||||
|
]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
|||||||
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
|
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
|
||||||
def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
|
def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||||
caption_loss = contrastive_loss(similarity)
|
caption_loss = contrastive_loss(similarity)
|
||||||
image_loss = contrastive_loss(similarity.T)
|
image_loss = contrastive_loss(similarity.t())
|
||||||
return (caption_loss + image_loss) / 2.0
|
return (caption_loss + image_loss) / 2.0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
|||||||
# Copied from transformers.models.clip.modeling_clip.clip_loss
|
# Copied from transformers.models.clip.modeling_clip.clip_loss
|
||||||
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||||
caption_loss = contrastive_loss(similarity)
|
caption_loss = contrastive_loss(similarity)
|
||||||
image_loss = contrastive_loss(similarity.T)
|
image_loss = contrastive_loss(similarity.t())
|
||||||
return (caption_loss + image_loss) / 2.0
|
return (caption_loss + image_loss) / 2.0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -201,6 +201,10 @@ class FeaturesManager:
|
|||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls="models.camembert.CamembertOnnxConfig",
|
onnx_config_cls="models.camembert.CamembertOnnxConfig",
|
||||||
),
|
),
|
||||||
|
"clip": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
onnx_config_cls="models.clip.CLIPOnnxConfig",
|
||||||
|
),
|
||||||
"codegen": supported_features_mapping(
|
"codegen": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("big-bird", "google/bigbird-roberta-base"),
|
("big-bird", "google/bigbird-roberta-base"),
|
||||||
("ibert", "kssteven/ibert-roberta-base"),
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
("camembert", "camembert-base"),
|
("camembert", "camembert-base"),
|
||||||
|
("clip", "openai/clip-vit-base-patch32"),
|
||||||
("convbert", "YituTech/conv-bert-base"),
|
("convbert", "YituTech/conv-bert-base"),
|
||||||
("codegen", "Salesforce/codegen-350M-multi"),
|
("codegen", "Salesforce/codegen-350M-multi"),
|
||||||
("deberta", "microsoft/deberta-base"),
|
("deberta", "microsoft/deberta-base"),
|
||||||
|
|||||||
Reference in New Issue
Block a user