🚨 Add Blip2ForImageTextRetrieval (#29261)
* add Blip2ForImageTextRetrieval * use one line and remove unnecessary space in tests Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * use value from the config, rather than hardcoded * change order of params in Blip2QFormerModel.forward * update docstring * fix style * update test_inference_opt * move embeddings out of Blip2QFormerModel * remove from_vision_qformer_configs * remove autocast float16 in Blip2QFormerModel * rename fields into vision_projection, text_projection, use_image_text_matching_head * use CLIPOutput for Blip2ImageTextMatchingModelOutput * remove past_key_values_length from Blip2TextEmbeddings * fix small typo in the CLIPOutput docstring * add Blip2ForImageTextRetrieval to Zero Shot Image Classification mapping * update docstring and add require_torch_fp16 * rollback test_inference_opt * use use_image_text_matching_head=True in convert * skip test_model_get_set_embeddings * fix create_rename_keys error on new itm fields * revert to do scale after dot product between "query" and "key" * fix ValueError on convert script for blip2-opt-2.7b * update org of paths to Salesforce * add is_pipeline_test_to_skip for VisualQuestionAnsweringPipelineTests * [run_slow] blip_2 * removed Blip2ForImageTextRetrieval from IGNORE_NON_AUTO_CONFIGURED * fix docstring of Blip2ImageTextMatchingModelOutput * [run_slow] blip_2 * fix multi-gpu tests * [run_slow] blip_2 * [run_slow] blip_2 --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -88,3 +88,16 @@ If you're interested in submitting a resource to be included here, please feel f
|
|||||||
[[autodoc]] Blip2ForConditionalGeneration
|
[[autodoc]] Blip2ForConditionalGeneration
|
||||||
- forward
|
- forward
|
||||||
- generate
|
- generate
|
||||||
|
|
||||||
|
## Blip2ForImageTextRetrieval
|
||||||
|
|
||||||
|
[[autodoc]] Blip2ForImageTextRetrieval
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## Blip2TextModelWithProjection
|
||||||
|
|
||||||
|
[[autodoc]] Blip2TextModelWithProjection
|
||||||
|
|
||||||
|
## Blip2VisionModelWithProjection
|
||||||
|
|
||||||
|
[[autodoc]] Blip2VisionModelWithProjection
|
||||||
|
|||||||
@@ -1578,10 +1578,13 @@ else:
|
|||||||
_import_structure["models.blip_2"].extend(
|
_import_structure["models.blip_2"].extend(
|
||||||
[
|
[
|
||||||
"Blip2ForConditionalGeneration",
|
"Blip2ForConditionalGeneration",
|
||||||
|
"Blip2ForImageTextRetrieval",
|
||||||
"Blip2Model",
|
"Blip2Model",
|
||||||
"Blip2PreTrainedModel",
|
"Blip2PreTrainedModel",
|
||||||
"Blip2QFormerModel",
|
"Blip2QFormerModel",
|
||||||
|
"Blip2TextModelWithProjection",
|
||||||
"Blip2VisionModel",
|
"Blip2VisionModel",
|
||||||
|
"Blip2VisionModelWithProjection",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.bloom"].extend(
|
_import_structure["models.bloom"].extend(
|
||||||
@@ -6327,10 +6330,13 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.blip_2 import (
|
from .models.blip_2 import (
|
||||||
Blip2ForConditionalGeneration,
|
Blip2ForConditionalGeneration,
|
||||||
|
Blip2ForImageTextRetrieval,
|
||||||
Blip2Model,
|
Blip2Model,
|
||||||
Blip2PreTrainedModel,
|
Blip2PreTrainedModel,
|
||||||
Blip2QFormerModel,
|
Blip2QFormerModel,
|
||||||
|
Blip2TextModelWithProjection,
|
||||||
Blip2VisionModel,
|
Blip2VisionModel,
|
||||||
|
Blip2VisionModelWithProjection,
|
||||||
)
|
)
|
||||||
from .models.bloom import (
|
from .models.bloom import (
|
||||||
BloomForCausalLM,
|
BloomForCausalLM,
|
||||||
|
|||||||
@@ -161,19 +161,19 @@ class AltCLIPOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||||
Contrastive loss for image-text similarity.
|
Contrastive loss for image-text similarity.
|
||||||
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||||
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||||
similarity scores.
|
similarity scores.
|
||||||
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||||
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||||
similarity scores.
|
similarity scores.
|
||||||
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The text embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPTextModel`].
|
The text embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPTextModel`].
|
||||||
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The image embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPVisionModel`].
|
The image embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPVisionModel`].
|
||||||
text_model_output(`BaseModelOutputWithPooling`):
|
text_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`AltCLIPTextModel`].
|
The output of the [`AltCLIPTextModel`].
|
||||||
vision_model_output(`BaseModelOutputWithPooling`):
|
vision_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`AltCLIPVisionModel`].
|
The output of the [`AltCLIPVisionModel`].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -1266,6 +1266,7 @@ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("align", "AlignModel"),
|
("align", "AlignModel"),
|
||||||
("altclip", "AltCLIPModel"),
|
("altclip", "AltCLIPModel"),
|
||||||
("blip", "BlipModel"),
|
("blip", "BlipModel"),
|
||||||
|
("blip-2", "Blip2ForImageTextRetrieval"),
|
||||||
("chinese_clip", "ChineseCLIPModel"),
|
("chinese_clip", "ChineseCLIPModel"),
|
||||||
("clip", "CLIPModel"),
|
("clip", "CLIPModel"),
|
||||||
("clipseg", "CLIPSegModel"),
|
("clipseg", "CLIPSegModel"),
|
||||||
|
|||||||
@@ -33,10 +33,13 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_blip_2"] = [
|
_import_structure["modeling_blip_2"] = [
|
||||||
"Blip2Model",
|
"Blip2Model",
|
||||||
|
"Blip2VisionModelWithProjection",
|
||||||
"Blip2QFormerModel",
|
"Blip2QFormerModel",
|
||||||
"Blip2PreTrainedModel",
|
"Blip2PreTrainedModel",
|
||||||
"Blip2ForConditionalGeneration",
|
"Blip2ForConditionalGeneration",
|
||||||
|
"Blip2ForImageTextRetrieval",
|
||||||
"Blip2VisionModel",
|
"Blip2VisionModel",
|
||||||
|
"Blip2TextModelWithProjection",
|
||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -55,10 +58,13 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_blip_2 import (
|
from .modeling_blip_2 import (
|
||||||
Blip2ForConditionalGeneration,
|
Blip2ForConditionalGeneration,
|
||||||
|
Blip2ForImageTextRetrieval,
|
||||||
Blip2Model,
|
Blip2Model,
|
||||||
Blip2PreTrainedModel,
|
Blip2PreTrainedModel,
|
||||||
Blip2QFormerModel,
|
Blip2QFormerModel,
|
||||||
|
Blip2TextModelWithProjection,
|
||||||
Blip2VisionModel,
|
Blip2VisionModel,
|
||||||
|
Blip2VisionModelWithProjection,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
"""BLIP-2 model configuration"""
|
"""BLIP-2 model configuration"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
@@ -172,6 +172,8 @@ class Blip2QFormerConfig(PretrainedConfig):
|
|||||||
The frequency of adding cross-attention to the Transformer layers.
|
The frequency of adding cross-attention to the Transformer layers.
|
||||||
encoder_hidden_size (`int`, *optional*, defaults to 1408):
|
encoder_hidden_size (`int`, *optional*, defaults to 1408):
|
||||||
The hidden size of the hidden states for cross-attention.
|
The hidden size of the hidden states for cross-attention.
|
||||||
|
use_qformer_text_input (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use BERT-style embeddings.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -206,6 +208,7 @@ class Blip2QFormerConfig(PretrainedConfig):
|
|||||||
position_embedding_type="absolute",
|
position_embedding_type="absolute",
|
||||||
cross_attention_frequency=2,
|
cross_attention_frequency=2,
|
||||||
encoder_hidden_size=1408,
|
encoder_hidden_size=1408,
|
||||||
|
use_qformer_text_input=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||||
@@ -224,6 +227,7 @@ class Blip2QFormerConfig(PretrainedConfig):
|
|||||||
self.position_embedding_type = position_embedding_type
|
self.position_embedding_type = position_embedding_type
|
||||||
self.cross_attention_frequency = cross_attention_frequency
|
self.cross_attention_frequency = cross_attention_frequency
|
||||||
self.encoder_hidden_size = encoder_hidden_size
|
self.encoder_hidden_size = encoder_hidden_size
|
||||||
|
self.use_qformer_text_input = use_qformer_text_input
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||||
@@ -263,6 +267,8 @@ class Blip2Config(PretrainedConfig):
|
|||||||
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
|
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
|
||||||
num_query_tokens (`int`, *optional*, defaults to 32):
|
num_query_tokens (`int`, *optional*, defaults to 32):
|
||||||
The number of query tokens passed through the Transformer.
|
The number of query tokens passed through the Transformer.
|
||||||
|
image_text_hidden_size (`int`, *optional*, defaults to 256):
|
||||||
|
Dimensionality of the hidden state of the image-text fusion layer.
|
||||||
|
|
||||||
image_token_index (`int`, *optional*):
|
image_token_index (`int`, *optional*):
|
||||||
Token index of special image token.
|
Token index of special image token.
|
||||||
@@ -307,6 +313,7 @@ class Blip2Config(PretrainedConfig):
|
|||||||
qformer_config=None,
|
qformer_config=None,
|
||||||
text_config=None,
|
text_config=None,
|
||||||
num_query_tokens=32,
|
num_query_tokens=32,
|
||||||
|
image_text_hidden_size=256,
|
||||||
image_token_index=None,
|
image_token_index=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -333,6 +340,7 @@ class Blip2Config(PretrainedConfig):
|
|||||||
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
||||||
|
|
||||||
self.num_query_tokens = num_query_tokens
|
self.num_query_tokens = num_query_tokens
|
||||||
|
self.image_text_hidden_size = image_text_hidden_size
|
||||||
self.image_token_index = image_token_index
|
self.image_token_index = image_token_index
|
||||||
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
||||||
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
@@ -344,13 +352,21 @@ class Blip2Config(PretrainedConfig):
|
|||||||
cls,
|
cls,
|
||||||
vision_config: Blip2VisionConfig,
|
vision_config: Blip2VisionConfig,
|
||||||
qformer_config: Blip2QFormerConfig,
|
qformer_config: Blip2QFormerConfig,
|
||||||
text_config: PretrainedConfig,
|
text_config: Optional[PretrainedConfig] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision model, Q-Former and language model
|
Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision model, Q-Former and language model
|
||||||
configurations.
|
configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vision_config (`dict`):
|
||||||
|
Dictionary of configuration options used to initialize [`Blip2VisionConfig`].
|
||||||
|
qformer_config (`dict`):
|
||||||
|
Dictionary of configuration options used to initialize [`Blip2QFormerConfig`].
|
||||||
|
text_config (`dict`, *optional*):
|
||||||
|
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`Blip2Config`]: An instance of a configuration object
|
[`Blip2Config`]: An instance of a configuration object
|
||||||
"""
|
"""
|
||||||
@@ -358,6 +374,6 @@ class Blip2Config(PretrainedConfig):
|
|||||||
return cls(
|
return cls(
|
||||||
vision_config=vision_config.to_dict(),
|
vision_config=vision_config.to_dict(),
|
||||||
qformer_config=qformer_config.to_dict(),
|
qformer_config=qformer_config.to_dict(),
|
||||||
text_config=text_config.to_dict(),
|
text_config=text_config.to_dict() if text_config is not None else None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,9 +31,12 @@ from PIL import Image
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
BertTokenizer,
|
||||||
Blip2Config,
|
Blip2Config,
|
||||||
Blip2ForConditionalGeneration,
|
Blip2ForConditionalGeneration,
|
||||||
|
Blip2ForImageTextRetrieval,
|
||||||
Blip2Processor,
|
Blip2Processor,
|
||||||
|
Blip2QFormerConfig,
|
||||||
Blip2VisionConfig,
|
Blip2VisionConfig,
|
||||||
BlipImageProcessor,
|
BlipImageProcessor,
|
||||||
OPTConfig,
|
OPTConfig,
|
||||||
@@ -51,7 +54,7 @@ def load_demo_image():
|
|||||||
|
|
||||||
|
|
||||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||||
def create_rename_keys(config):
|
def create_rename_keys(config, model_name):
|
||||||
rename_keys = []
|
rename_keys = []
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
|
||||||
@@ -79,6 +82,13 @@ def create_rename_keys(config):
|
|||||||
# QFormer
|
# QFormer
|
||||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
|
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
|
||||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
|
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
|
||||||
|
if "itm" in model_name:
|
||||||
|
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
|
||||||
|
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
|
||||||
|
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
|
||||||
|
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
|
||||||
|
rename_keys.append(("text_proj.weight", "text_projection.weight"))
|
||||||
|
rename_keys.append(("text_proj.bias", "text_projection.bias"))
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
return rename_keys
|
return rename_keys
|
||||||
@@ -114,25 +124,46 @@ def get_blip2_config(model_name, eos_token_id):
|
|||||||
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||||
elif "t5-xxl" in model_name:
|
elif "t5-xxl" in model_name:
|
||||||
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||||
|
elif "itm" in model_name:
|
||||||
|
text_config = {}
|
||||||
|
else:
|
||||||
|
raise ValueError("Model name not supported")
|
||||||
|
|
||||||
|
if "itm" in model_name:
|
||||||
|
config = Blip2Config(
|
||||||
|
vision_config=vision_config,
|
||||||
|
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
config = Blip2Config(vision_config=vision_config, text_config=text_config)
|
config = Blip2Config(vision_config=vision_config, text_config=text_config)
|
||||||
|
|
||||||
return config, image_size
|
return config, image_size
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
|
def convert_blip2_checkpoint(
|
||||||
|
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Copy/paste/tweak model's weights to Transformers design.
|
Copy/paste/tweak model's weights to Transformers design.
|
||||||
"""
|
"""
|
||||||
tokenizer = (
|
if "opt" in model_name:
|
||||||
AutoTokenizer.from_pretrained("facebook/opt-2.7b")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
|
||||||
if "opt" in model_name
|
elif "itm" in model_name:
|
||||||
else AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||||
)
|
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||||
|
else:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
||||||
|
|
||||||
|
if "itm" in model_name:
|
||||||
|
eos_token_id = None
|
||||||
|
else:
|
||||||
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
|
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
|
||||||
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
|
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
|
||||||
|
|
||||||
|
if "itm" in model_name:
|
||||||
|
hf_model = Blip2ForImageTextRetrieval(config).eval()
|
||||||
|
else:
|
||||||
hf_model = Blip2ForConditionalGeneration(config).eval()
|
hf_model = Blip2ForConditionalGeneration(config).eval()
|
||||||
|
|
||||||
model_name_to_original = {
|
model_name_to_original = {
|
||||||
@@ -143,16 +174,12 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
|
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
|
||||||
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
|
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
|
||||||
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
|
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
|
||||||
|
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
|
||||||
|
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
|
||||||
}
|
}
|
||||||
|
|
||||||
name, type = model_name_to_original[model_name]
|
name, type = model_name_to_original[model_name]
|
||||||
|
|
||||||
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
|
||||||
# which requires quite some memory. Hence loading both on a
|
|
||||||
# separate device is the easiest to compare
|
|
||||||
hf_model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
||||||
lavis_device = "cuda:1" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
# load original model
|
# load original model
|
||||||
print("Loading original model...")
|
print("Loading original model...")
|
||||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||||
@@ -163,7 +190,7 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
|
|
||||||
# update state dict keys
|
# update state dict keys
|
||||||
state_dict = original_model.state_dict()
|
state_dict = original_model.state_dict()
|
||||||
rename_keys = create_rename_keys(config)
|
rename_keys = create_rename_keys(config, model_name)
|
||||||
for src, dest in rename_keys:
|
for src, dest in rename_keys:
|
||||||
rename_key(state_dict, src, dest)
|
rename_key(state_dict, src, dest)
|
||||||
|
|
||||||
@@ -189,11 +216,15 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
|
|
||||||
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
||||||
assert len(missing_keys) == 0
|
assert len(missing_keys) == 0
|
||||||
|
|
||||||
|
if "itm" in model_name:
|
||||||
|
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
|
||||||
|
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
|
||||||
|
else:
|
||||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||||
|
|
||||||
image = load_demo_image()
|
image = load_demo_image()
|
||||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
|
||||||
|
|
||||||
# create processor
|
# create processor
|
||||||
image_processor = BlipImageProcessor(
|
image_processor = BlipImageProcessor(
|
||||||
@@ -207,6 +238,61 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
|
|
||||||
original_model.to(lavis_device)
|
original_model.to(lavis_device)
|
||||||
hf_model.to(hf_model_device)
|
hf_model.to(hf_model_device)
|
||||||
|
|
||||||
|
if "itm" in model_name:
|
||||||
|
caption = "a large fountain spewing water into the air"
|
||||||
|
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
|
||||||
|
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
original_logits = original_model(
|
||||||
|
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
|
||||||
|
)
|
||||||
|
logits = hf_model(
|
||||||
|
pixel_values=original_pixel_values,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
use_image_text_matching_head=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert original_logits.shape == logits.logits_per_image.shape
|
||||||
|
print("First values of original logits:", original_logits[0, :3])
|
||||||
|
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||||
|
|
||||||
|
# assert values
|
||||||
|
# cast to same type
|
||||||
|
target_dtype = logits.logits_per_image.dtype
|
||||||
|
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||||
|
|
||||||
|
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
|
||||||
|
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
|
||||||
|
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
|
||||||
|
print("Looks ok!")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
original_logits = original_model(
|
||||||
|
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
|
||||||
|
)
|
||||||
|
logits = hf_model(
|
||||||
|
pixel_values=original_pixel_values,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
use_image_text_matching_head=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert original_logits.shape == logits.logits_per_image.shape
|
||||||
|
print("First values of original logits:", original_logits[0, :3])
|
||||||
|
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||||
|
|
||||||
|
# assert values
|
||||||
|
# cast to same type
|
||||||
|
target_dtype = logits.logits_per_image.dtype
|
||||||
|
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||||
|
print("Looks ok!")
|
||||||
|
|
||||||
|
else:
|
||||||
|
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "opt" in model_name:
|
if "opt" in model_name:
|
||||||
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
||||||
@@ -233,7 +319,7 @@ def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_
|
|||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|
||||||
original_outputs = original_model.generate(
|
original_outputs = original_model.generate(
|
||||||
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True
|
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
|
||||||
)
|
)
|
||||||
outputs = hf_model.generate(
|
outputs = hf_model.generate(
|
||||||
pixel_values,
|
pixel_values,
|
||||||
@@ -271,6 +357,8 @@ if __name__ == "__main__":
|
|||||||
"blip2-flan-t5-xl",
|
"blip2-flan-t5-xl",
|
||||||
"blip2-flan-t5-xl-coco",
|
"blip2-flan-t5-xl-coco",
|
||||||
"blip2-flan-t5-xxl",
|
"blip2-flan-t5-xxl",
|
||||||
|
"blip2-itm-vit-g",
|
||||||
|
"blip2-itm-vit-g-coco",
|
||||||
]
|
]
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name",
|
"--model_name",
|
||||||
@@ -285,7 +373,18 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to push the model and processor to the hub after converting",
|
help="Whether to push the model and processor to the hub after converting",
|
||||||
)
|
)
|
||||||
|
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
||||||
|
# which requires quite some memory. Hence loading both on a
|
||||||
|
# separate device is the easiest to compare
|
||||||
|
parser.add_argument(
|
||||||
|
"--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
convert_blip2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
convert_blip2_checkpoint(
|
||||||
|
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
|
||||||
|
)
|
||||||
|
|||||||
@@ -81,6 +81,103 @@ class Blip2ForConditionalGenerationModelOutput(ModelOutput):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Blip2ImageTextMatchingModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||||
|
Contrastive loss for image-text similarity.
|
||||||
|
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||||
|
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||||
|
similarity scores.
|
||||||
|
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||||
|
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||||
|
similarity scores.
|
||||||
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
|
The text embeddings obtained by applying the projection layer to the pooled output.
|
||||||
|
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
|
The image embeddings obtained by applying the projection layer to the pooled output.
|
||||||
|
text_model_output (`BaseModelOutputWithPooling`):
|
||||||
|
The output of the [`Blip2QFormerModel`].
|
||||||
|
vision_model_output (`BaseModelOutputWithPooling`):
|
||||||
|
The output of the [`Blip2VisionModel`].
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
logits_per_image: torch.FloatTensor = None
|
||||||
|
logits_per_text: torch.FloatTensor = None
|
||||||
|
text_embeds: torch.FloatTensor = None
|
||||||
|
image_embeds: torch.FloatTensor = None
|
||||||
|
text_model_output: BaseModelOutputWithPooling = None
|
||||||
|
vision_model_output: BaseModelOutputWithPooling = None
|
||||||
|
|
||||||
|
def to_tuple(self) -> Tuple[Any]:
|
||||||
|
return tuple(
|
||||||
|
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
||||||
|
for k in self.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Blip2
|
||||||
|
class Blip2TextModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for text model's outputs that also contains a pooling of the last hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
||||||
|
The text embeddings obtained by applying the projection layer to the pooler_output.
|
||||||
|
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||||
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||||
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||||
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
text_embeds: Optional[torch.FloatTensor] = None
|
||||||
|
last_hidden_state: torch.FloatTensor = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Output container returned by `Blip2VisionModelWithProjection.forward`: that method stores the projected and
# normalized Q-Former outputs in `image_embeds` and copies the remaining fields from the vision model's outputs.
# NOTE(review): the docstring documents `image_embeds` as `(batch_size, output_dim)`, while the usage example in
# `Blip2VisionModelWithProjection.forward` prints a 3-D shape (`[1, 32, 256]`) -- confirm which one is intended.
# NOTE(review): the class body is presumably kept in sync with CLIP through the `# Copied from` marker below, so a
# wording change most likely has to be made in the CLIP original -- confirm before editing the body here.
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Blip2
class Blip2VisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
|
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
|
||||||
class Blip2VisionEmbeddings(nn.Module):
|
class Blip2VisionEmbeddings(nn.Module):
|
||||||
def __init__(self, config: Blip2VisionConfig):
|
def __init__(self, config: Blip2VisionConfig):
|
||||||
@@ -304,7 +401,13 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
|||||||
config_class = Blip2Config
|
config_class = Blip2Config
|
||||||
base_model_prefix = "blip"
|
base_model_prefix = "blip"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
|
_no_split_modules = [
|
||||||
|
"Blip2Attention",
|
||||||
|
"Blip2QFormerMultiHeadAttention",
|
||||||
|
"Blip2TextEmbeddings",
|
||||||
|
"T5Block",
|
||||||
|
"OPTDecoderLayer",
|
||||||
|
]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_keep_in_fp32_modules = ["wo"]
|
_keep_in_fp32_modules = ["wo"]
|
||||||
|
|
||||||
@@ -398,6 +501,30 @@ BLIP_2_TEXT_INPUTS_DOCSTRING = r"""
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
|
config.max_position_embeddings - 1]`.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
BLIP_2_INPUTS_DOCSTRING = r"""
|
BLIP_2_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||||
@@ -444,6 +571,43 @@ BLIP_2_INPUTS_DOCSTRING = r"""
|
|||||||
Whether to interpolate the pre-trained position encodings.
|
Whether to interpolate the pre-trained position encodings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Argument documentation for `Blip2ForImageTextRetrieval.forward`; it is attached to that method through the
# `add_start_docstrings_to_model_forward` decorator, so every entry should match a parameter of that signature.
BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`Blip2Processor`]. See [`Blip2Processor.__call__`] for
            details.

        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of the text tokens in the vocabulary. The text is embedded and scored against the image, either
            with the image-text matching head or with the contrastive projections.

            Indices can be obtained using [`Blip2Processor`]. See [`Blip2Processor.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

        use_image_text_matching_head (`bool`, *optional*, defaults to `False`):
            Whether to return the Image-Text Matching or Contrastive scores. If `True`, the scores come from the
            image-text matching head; if `False`, they are the contrastive image-text similarity scores.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
|
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
|
||||||
class Blip2Encoder(nn.Module):
|
class Blip2Encoder(nn.Module):
|
||||||
@@ -842,6 +1006,10 @@ class Blip2QFormerLayer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.has_cross_attention = False
|
self.has_cross_attention = False
|
||||||
|
|
||||||
|
if config.use_qformer_text_input:
|
||||||
|
self.intermediate = Blip2QFormerIntermediate(config)
|
||||||
|
self.output = Blip2QFormerOutput(config)
|
||||||
|
|
||||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||||
self.output_query = Blip2QFormerOutput(config)
|
self.output_query = Blip2QFormerOutput(config)
|
||||||
|
|
||||||
@@ -1022,6 +1190,49 @@ class Blip2QFormerEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2TextEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    The module can also prepend already-embedded query tokens to the embedded text, so that queries and text form a
    single sequence along dimension 1.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # position_ids has shape (1, max_position_embeddings); it is registered as a non-persistent buffer, so it
        # follows the module across devices but is not stored in the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        query_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Embed `input_ids` and optionally prepend `query_embeds`.

        Args:
            input_ids: Token indices to look up in the word embedding table. May be omitted when only
                `query_embeds` should be passed through.
            position_ids: Position indices for `input_ids`. Defaults to the first `sequence_length` entries of the
                `position_ids` buffer.
            query_embeds: Already-embedded query tokens. When given together with `input_ids` they are concatenated
                in front of the text embeddings along dimension 1; when given alone they are returned unchanged.

        Returns:
            The text embeddings (word embeddings plus position embeddings when `position_embedding_type` is
            `"absolute"`), preceded by `query_embeds` if provided.

        Raises:
            ValueError: If neither `input_ids` nor `query_embeds` is provided.
        """
        if input_ids is None:
            if query_embeds is None:
                # Previously this case silently returned `None`, contradicting the `torch.Tensor` return type.
                raise ValueError("You have to specify at least one of `input_ids` or `query_embeds`")
            return query_embeds

        if position_ids is None:
            position_ids = self.position_ids[:, : input_ids.size(1)]

        # The ids may arrive on a different device than the embedding table.
        input_ids = input_ids.to(self.word_embeddings.weight.device)
        embeddings = self.word_embeddings(input_ids)
        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)

        if query_embeds is not None:
            embeddings = torch.cat((query_embeds, embeddings), dim=1)

        return embeddings
|
||||||
|
|
||||||
|
|
||||||
class Blip2QFormerModel(Blip2PreTrainedModel):
|
class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
Querying Transformer (Q-Former), used in BLIP-2.
|
Querying Transformer (Q-Former), used in BLIP-2.
|
||||||
@@ -1100,6 +1311,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
query_embeds: torch.FloatTensor,
|
query_embeds: torch.FloatTensor,
|
||||||
|
query_length: Optional[int] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
@@ -1140,7 +1352,9 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
|||||||
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
|
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
query_length = query_embeds.shape[1] if query_embeds is not None else 0
|
query_length = (
|
||||||
|
query_length if query_length is not None else query_embeds.shape[1] if query_embeds is not None else 0
|
||||||
|
)
|
||||||
|
|
||||||
embedding_output = self.layernorm(query_embeds)
|
embedding_output = self.layernorm(query_embeds)
|
||||||
embedding_output = self.dropout(embedding_output)
|
embedding_output = self.dropout(embedding_output)
|
||||||
@@ -1567,6 +1781,206 @@ class Blip2Model(Blip2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    BLIP-2 Text Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    BLIP_2_START_DOCSTRING,
)
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
    supports_gradient_checkpointing = False
    _keep_in_fp32_modules = []

    def __init__(self, config: Blip2Config):
        super().__init__(config)
        qformer_config = config.qformer_config

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, qformer_config.hidden_size))
        self.embeddings = Blip2TextEmbeddings(qformer_config)
        self.qformer = Blip2QFormerModel(qformer_config)

        # Linear map from the Q-Former hidden size to the shared image-text embedding size.
        self.text_projection = nn.Linear(qformer_config.hidden_size, config.image_text_hidden_size)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2TextModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, Blip2TextModelWithProjection

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> model = Blip2TextModelWithProjection.from_pretrained(
        ...     "Salesforce/blip2-itm-vit-g", torch_dtype=torch.float16
        ... )

        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g")

        >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="pt").to(device)

        >>> outputs = model(**inputs)
        >>> text_embeds = outputs.text_embeds
        >>> print(text_embeds.shape)
        torch.Size([2, 7, 256])
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Only the text is embedded here; no query tokens are prepended (hence `query_length=0` below).
        token_embeds = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        qformer_outputs = self.qformer(
            query_embeds=token_embeds,
            query_length=0,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = qformer_outputs.last_hidden_state if return_dict else qformer_outputs[0]

        # Project every position into the image-text space and normalize along the feature axis.
        text_embeds = nn.functional.normalize(self.text_projection(sequence_output), dim=-1)

        if return_dict:
            return Blip2TextModelOutput(
                text_embeds=text_embeds,
                last_hidden_state=qformer_outputs.last_hidden_state,
                hidden_states=qformer_outputs.hidden_states,
                attentions=qformer_outputs.attentions,
            )

        # Tuple output: skip index 1 of the Q-Former tuple and drop entries that were not computed.
        candidates = (text_embeds, qformer_outputs[0]) + qformer_outputs[2:]
        return tuple(item for item in candidates if item is not None)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    BLIP-2 Vision Model with a projection layer on top (a linear layer on top of the pooled output).
    """,
    BLIP_2_START_DOCSTRING,
)
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
    main_input_name = "pixel_values"
    _keep_in_fp32_modules = []

    def __init__(self, config: Blip2Config):
        super().__init__(config)
        qformer_config = config.qformer_config

        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, qformer_config.hidden_size))
        self.qformer = Blip2QFormerModel(qformer_config)

        # Linear map from the Q-Former hidden size to the shared image-text embedding size.
        self.vision_projection = nn.Linear(qformer_config.hidden_size, config.image_text_hidden_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch embedding module of the vision model."""
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Blip2VisionModelOutput, config_class=Blip2Config)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2VisionModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Blip2VisionModelWithProjection

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g")
        >>> model = Blip2VisionModelWithProjection.from_pretrained(
        ...     "Salesforce/blip2-itm-vit-g", torch_dtype=torch.float16
        ... )
        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)

        >>> outputs = model(**inputs)
        >>> image_embeds = outputs.image_embeds
        >>> print(image_embeds.shape)
        torch.Size([1, 32, 256])
        ```"""
        # Fall back to the configuration for every flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        vision_hidden = vision_outputs.last_hidden_state if return_dict else vision_outputs[0]

        # Every vision position may be attended to: all-ones mask over the leading dimensions.
        vision_mask = torch.ones(vision_hidden.shape[:-1], dtype=torch.long, device=vision_hidden.device)

        # Broadcast the learned queries over the batch.
        batched_queries = self.query_tokens.expand(vision_hidden.shape[0], -1, -1)

        qformer_outputs = self.qformer(
            query_embeds=batched_queries,
            encoder_hidden_states=vision_hidden,
            encoder_attention_mask=vision_mask,
            return_dict=return_dict,
        )
        query_hidden = qformer_outputs.last_hidden_state if return_dict else qformer_outputs[0]

        # Project the query outputs into the image-text space and normalize along the feature axis.
        image_embeds = nn.functional.normalize(self.vision_projection(query_hidden), dim=-1)

        if return_dict:
            return Blip2VisionModelOutput(
                image_embeds=image_embeds,
                last_hidden_state=vision_outputs.last_hidden_state,
                hidden_states=vision_outputs.hidden_states,
                attentions=vision_outputs.attentions,
            )

        # Tuple output: skip index 1 of the vision tuple and drop entries that were not computed.
        candidates = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
        return tuple(item for item in candidates if item is not None)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision
|
BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision
|
||||||
@@ -1937,3 +2351,180 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
outputs = torch.cat([bos_tokens, outputs], dim=-1)
|
outputs = torch.cat([bos_tokens, outputs], dim=-1)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    BLIP-2 Model with a vision and text projector, and a classification head on top. The model is used in the context
    of image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
    the image.
    """,
    BLIP_2_START_DOCSTRING,
)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
    main_input_name = "pixel_values"
    _keep_in_fp32_modules = []

    def __init__(self, config: Blip2Config):
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))

        self.embeddings = Blip2TextEmbeddings(config.qformer_config)
        self.qformer = Blip2QFormerModel(config.qformer_config)

        # vision projection layer
        self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size)

        # text projection layer
        self.text_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size)

        # image text matching head: two logits per query position
        self.itm_head = nn.Linear(config.qformer_config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        use_image_text_matching_head: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ImageTextMatchingModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Blip2ForImageTextRetrieval

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> model = Blip2ForImageTextRetrieval.from_pretrained("Salesforce/blip2-itm-vit-g", torch_dtype=torch.float16)
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g")

        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "two cats laying on a pink blanket"

        >>> inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16)
        >>> itm_out = model(**inputs, use_image_text_matching_head=True)
        >>> logits_per_image = torch.nn.functional.softmax(itm_out.logits_per_image, dim=1)
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

        >>> print(f"{probs[0][0]:.1%} that image 0 is not '{text}'")
        26.9% that image 0 is not 'two cats laying on a pink blanket'

        >>> print(f"{probs[0][1]:.1%} that image 0 is '{text}'")
        73.0% that image 0 is 'two cats laying on a pink blanket'

        >>> texts = ["a photo of a cat", "a photo of a dog"]

        >>> inputs = processor(images=image, text=texts, return_tensors="pt").to(device, torch.float16)
        >>> itc_out = model(**inputs, use_image_text_matching_head=False)
        >>> logits_per_image = itc_out.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        55.3% that image 0 is 'a photo of a cat'

        >>> print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
        44.7% that image 0 is 'a photo of a dog'
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[0]
        # Every vision position may be attended to: all-ones mask over the leading dimensions.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # Both scoring modes use the same learned queries, broadcast over the image batch.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        if use_image_text_matching_head:
            # Image-text matching: queries and text go through the Q-Former as one sequence that cross-attends
            # to the image features.
            if attention_mask is None:
                # `attention_mask` is optional in the signature, but it is concatenated with the query mask below,
                # so a concrete mask is required; attend to every text token by default.
                attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=query_tokens.device)
            query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device)
            attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1)

            query_embeds = self.embeddings(
                input_ids=input_ids,
                query_embeds=query_tokens,
            )

            text_outputs = self.qformer(
                query_embeds=query_embeds,
                query_length=query_tokens.shape[1],
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
                return_dict=return_dict,
            )
            text_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state

            # Score only the leading query positions, then average the two-way logits over the queries.
            output = self.itm_head(text_embeds[:, : query_tokens.size(1), :])
            logits_per_image = output.mean(dim=1)
            logits_per_text = logits_per_image.t()
        else:
            # Contrastive scoring: image queries and text are encoded separately and compared after projection.
            query_outputs = self.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
                return_dict=return_dict,
            )
            image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state

            query_embeds = self.embeddings(
                input_ids=input_ids,
            )
            text_outputs = self.qformer(
                query_embeds=query_embeds,
                query_length=0,
                attention_mask=attention_mask,
                return_dict=return_dict,
            )
            question_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state

            # normalized features; the text side is represented by its first position only
            image_embeds = nn.functional.normalize(self.vision_projection(image_embeds), dim=-1)
            text_embeds = nn.functional.normalize(self.text_projection(question_embeds[:, 0, :]), dim=-1)

            # cosine similarity as logits
            logits_per_image = torch.matmul(image_embeds, text_embeds.t())
            # Reduce over dim 1 -- presumably the query axis -- keeping the highest similarity per image-text pair.
            logits_per_image, _ = logits_per_image.max(dim=1)

            logits_per_text = logits_per_image.t()

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return output

        return Blip2ImageTextMatchingModelOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
|
||||||
|
|||||||
@@ -195,19 +195,19 @@ class ClapOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||||
Contrastive loss for audio-text similarity.
|
Contrastive loss for audio-text similarity.
|
||||||
logits_per_audio:(`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`):
|
logits_per_audio (`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`):
|
||||||
The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text
|
The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text
|
||||||
similarity scores.
|
similarity scores.
|
||||||
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`):
|
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`):
|
||||||
The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio
|
The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio
|
||||||
similarity scores.
|
similarity scores.
|
||||||
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
||||||
The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`].
|
The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`].
|
||||||
audio_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
||||||
The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`].
|
The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`].
|
||||||
text_model_output(`BaseModelOutputWithPooling`):
|
text_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`ClapTextModel`].
|
The output of the [`ClapTextModel`].
|
||||||
audio_model_output(`BaseModelOutputWithPooling`):
|
audio_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`ClapAudioModel`].
|
The output of the [`ClapAudioModel`].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -142,19 +142,19 @@ class CLIPOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||||
Contrastive loss for image-text similarity.
|
Contrastive loss for image-text similarity.
|
||||||
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||||
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||||
similarity scores.
|
similarity scores.
|
||||||
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||||
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||||
similarity scores.
|
similarity scores.
|
||||||
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
|
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
|
||||||
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
||||||
text_model_output(`BaseModelOutputWithPooling`):
|
text_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`CLIPTextModel`].
|
The output of the [`CLIPTextModel`].
|
||||||
vision_model_output(`BaseModelOutputWithPooling`):
|
vision_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`CLIPVisionModel`].
|
The output of the [`CLIPVisionModel`].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -63,19 +63,19 @@ class CLIPSegOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||||
Contrastive loss for image-text similarity.
|
Contrastive loss for image-text similarity.
|
||||||
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||||
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||||
similarity scores.
|
similarity scores.
|
||||||
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||||
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||||
similarity scores.
|
similarity scores.
|
||||||
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
|
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
|
||||||
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
|
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
|
||||||
text_model_output(`BaseModelOutputWithPooling`):
|
text_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`CLIPSegTextModel`].
|
The output of the [`CLIPSegTextModel`].
|
||||||
vision_model_output(`BaseModelOutputWithPooling`):
|
vision_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`CLIPSegVisionModel`].
|
The output of the [`CLIPSegVisionModel`].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -215,19 +215,19 @@ class SiglipOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
||||||
Contrastive loss for image-text similarity.
|
Contrastive loss for image-text similarity.
|
||||||
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
||||||
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
||||||
similarity scores.
|
similarity scores.
|
||||||
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
||||||
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
||||||
similarity scores.
|
similarity scores.
|
||||||
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
|
The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
|
||||||
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
||||||
The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
|
The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
|
||||||
text_model_output(`BaseModelOutputWithPooling`):
|
text_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`SiglipTextModel`].
|
The output of the [`SiglipTextModel`].
|
||||||
vision_model_output(`BaseModelOutputWithPooling`):
|
vision_model_output (`BaseModelOutputWithPooling`):
|
||||||
The output of the [`SiglipVisionModel`].
|
The output of the [`SiglipVisionModel`].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,9 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||||||
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
|
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
|
||||||
the call may block forever.
|
the call may block forever.
|
||||||
|
|
||||||
|
tokenizer_kwargs (`dict`, *optional*):
|
||||||
|
Additional dictionary of keyword arguments passed along to the tokenizer.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A list of dictionaries containing result, one dictionary per proposed label. The dictionaries contain the
|
A list of dictionaries containing result, one dictionary per proposed label. The dictionaries contain the
|
||||||
following keys:
|
following keys:
|
||||||
@@ -106,7 +109,7 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||||||
"""
|
"""
|
||||||
return super().__call__(images, **kwargs)
|
return super().__call__(images, **kwargs)
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs):
|
def _sanitize_parameters(self, tokenizer_kwargs=None, **kwargs):
|
||||||
preprocess_params = {}
|
preprocess_params = {}
|
||||||
if "candidate_labels" in kwargs:
|
if "candidate_labels" in kwargs:
|
||||||
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
|
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
|
||||||
@@ -114,10 +117,21 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||||||
preprocess_params["timeout"] = kwargs["timeout"]
|
preprocess_params["timeout"] = kwargs["timeout"]
|
||||||
if "hypothesis_template" in kwargs:
|
if "hypothesis_template" in kwargs:
|
||||||
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
|
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
|
||||||
|
if tokenizer_kwargs is not None:
|
||||||
|
preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
|
||||||
|
|
||||||
return preprocess_params, {}, {}
|
return preprocess_params, {}, {}
|
||||||
|
|
||||||
def preprocess(self, image, candidate_labels=None, hypothesis_template="This is a photo of {}.", timeout=None):
|
def preprocess(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
candidate_labels=None,
|
||||||
|
hypothesis_template="This is a photo of {}.",
|
||||||
|
timeout=None,
|
||||||
|
tokenizer_kwargs=None,
|
||||||
|
):
|
||||||
|
if tokenizer_kwargs is None:
|
||||||
|
tokenizer_kwargs = {}
|
||||||
image = load_image(image, timeout=timeout)
|
image = load_image(image, timeout=timeout)
|
||||||
inputs = self.image_processor(images=[image], return_tensors=self.framework)
|
inputs = self.image_processor(images=[image], return_tensors=self.framework)
|
||||||
if self.framework == "pt":
|
if self.framework == "pt":
|
||||||
@@ -125,7 +139,7 @@ class ZeroShotImageClassificationPipeline(Pipeline):
|
|||||||
inputs["candidate_labels"] = candidate_labels
|
inputs["candidate_labels"] = candidate_labels
|
||||||
sequences = [hypothesis_template.format(x) for x in candidate_labels]
|
sequences = [hypothesis_template.format(x) for x in candidate_labels]
|
||||||
padding = "max_length" if self.model.config.model_type == "siglip" else True
|
padding = "max_length" if self.model.config.model_type == "siglip" else True
|
||||||
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=padding)
|
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=padding, **tokenizer_kwargs)
|
||||||
inputs["text_inputs"] = [text_inputs]
|
inputs["text_inputs"] = [text_inputs]
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|||||||
@@ -1603,6 +1603,13 @@ class Blip2ForConditionalGeneration(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2ForImageTextRetrieval(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Blip2Model(metaclass=DummyObject):
|
class Blip2Model(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -1624,6 +1631,13 @@ class Blip2QFormerModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2TextModelWithProjection(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Blip2VisionModel(metaclass=DummyObject):
|
class Blip2VisionModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -1631,6 +1645,13 @@ class Blip2VisionModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2VisionModelWithProjection(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class BloomForCausalLM(metaclass=DummyObject):
|
class BloomForCausalLM(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import requests
|
|||||||
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
|
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_fp16,
|
||||||
|
require_torch_gpu,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_vision,
|
require_vision,
|
||||||
slow,
|
slow,
|
||||||
@@ -47,7 +49,14 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel
|
from transformers import (
|
||||||
|
Blip2ForConditionalGeneration,
|
||||||
|
Blip2ForImageTextRetrieval,
|
||||||
|
Blip2Model,
|
||||||
|
Blip2TextModelWithProjection,
|
||||||
|
Blip2VisionModel,
|
||||||
|
Blip2VisionModelWithProjection,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -243,6 +252,7 @@ class Blip2QFormerModelTester:
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
scope=None,
|
scope=None,
|
||||||
|
use_qformer_text_input=False,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -262,6 +272,7 @@ class Blip2QFormerModelTester:
|
|||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
self.use_qformer_text_input = use_qformer_text_input
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -294,6 +305,7 @@ class Blip2QFormerModelTester:
|
|||||||
max_position_embeddings=self.max_position_embeddings,
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
|
use_qformer_text_input=self.use_qformer_text_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -489,7 +501,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
|||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
def test_load_vision_qformer_text_config(self):
|
def test_load_vision_qformer_text_config(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
# Save Blip2Config and check if we can load Blip2VisionConfig from it
|
# Save Blip2Config and check if we can load Blip2VisionConfig from it
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
@@ -704,6 +716,16 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
|||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
|
# TODO: Fix the failed tests
|
||||||
|
def is_pipeline_test_to_skip(
|
||||||
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||||
|
):
|
||||||
|
if pipeline_test_casse_name == "VisualQuestionAnsweringPipelineTests":
|
||||||
|
# Get `RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'`.
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = Blip2ModelTester(self)
|
self.model_tester = Blip2ModelTester(self)
|
||||||
|
|
||||||
@@ -752,7 +774,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
|||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
def test_load_vision_qformer_text_config(self):
|
def test_load_vision_qformer_text_config(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
# Save Blip2Config and check if we can load Blip2VisionConfig from it
|
# Save Blip2Config and check if we can load Blip2VisionConfig from it
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
@@ -840,6 +862,549 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2TextModelWithProjectionTester:
|
||||||
|
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
if qformer_kwargs is None:
|
||||||
|
qformer_kwargs = {"use_qformer_text_input": True}
|
||||||
|
|
||||||
|
self.parent = parent
|
||||||
|
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
|
||||||
|
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
|
||||||
|
self.is_training = is_training
|
||||||
|
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return Blip2Config.from_vision_qformer_text_configs(
|
||||||
|
vision_config=self.vision_model_tester.get_config(),
|
||||||
|
qformer_config=self.qformer_model_tester.get_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
_, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, input_ids, attention_mask = config_and_inputs
|
||||||
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, input_ids, attention_mask):
|
||||||
|
model = Blip2TextModelWithProjection(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True)
|
||||||
|
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.last_hidden_state.shape,
|
||||||
|
(self.vision_model_tester.batch_size, input_ids.shape[1], self.qformer_model_tester.hidden_size),
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.text_embeds.shape,
|
||||||
|
(
|
||||||
|
self.vision_model_tester.batch_size,
|
||||||
|
input_ids.shape[1],
|
||||||
|
config.image_text_hidden_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
result2 = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
return_dict=not config.use_return_dict,
|
||||||
|
output_attentions=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parent.assertTrue(torch.allclose(result.text_embeds, result2[0]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1]))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (Blip2TextModelWithProjection,) if is_torch_available() else ()
|
||||||
|
fx_compatible = False
|
||||||
|
test_pruning = False
|
||||||
|
test_head_masking = False
|
||||||
|
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_attention_outputs = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Blip2TextModelWithProjectionTester(self)
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Training is not yet supported")
|
||||||
|
def test_training(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Training is not yet supported")
|
||||||
|
def test_training_gradient_checkpointing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2TextModelWithProjection does not use inputs_embeds")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2TextModelWithProjection does not support input and output embeddings")
|
||||||
|
def test_model_get_set_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2TextModelWithProjection does not have input/output embeddings")
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||||
|
def test_save_load_fast_init_from_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||||
|
def test_save_load_fast_init_to_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["input_ids", "attention_mask", "position_ids"]
|
||||||
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
model_name = "Salesforce/blip2-itm-vit-g"
|
||||||
|
model = Blip2TextModelWithProjection.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertTrue(hasattr(model, "text_projection"))
|
||||||
|
|
||||||
|
_, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
outputs.text_embeds.shape,
|
||||||
|
(
|
||||||
|
self.model_tester.qformer_model_tester.batch_size,
|
||||||
|
input_ids.shape[1],
|
||||||
|
model.config.image_text_hidden_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2VisionModelWithProjectionTester:
|
||||||
|
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
|
||||||
|
if vision_kwargs is None:
|
||||||
|
vision_kwargs = {}
|
||||||
|
if qformer_kwargs is None:
|
||||||
|
qformer_kwargs = {"use_qformer_text_input": True}
|
||||||
|
|
||||||
|
self.parent = parent
|
||||||
|
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
|
||||||
|
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
|
||||||
|
self.is_training = is_training
|
||||||
|
self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
|
||||||
|
self.num_attention_heads = self.vision_model_tester.num_attention_heads
|
||||||
|
self.seq_length = self.vision_model_tester.seq_length
|
||||||
|
self.hidden_size = self.vision_model_tester.hidden_size
|
||||||
|
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return Blip2Config.from_vision_qformer_text_configs(
|
||||||
|
vision_config=self.vision_model_tester.get_config(),
|
||||||
|
qformer_config=self.qformer_model_tester.get_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {"pixel_values": pixel_values}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, pixel_values):
|
||||||
|
model = Blip2VisionModelWithProjection(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
result = model(pixel_values, output_attentions=True, output_hidden_states=True)
|
||||||
|
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.last_hidden_state.shape,
|
||||||
|
(
|
||||||
|
self.vision_model_tester.batch_size,
|
||||||
|
self.vision_model_tester.seq_length,
|
||||||
|
self.qformer_model_tester.hidden_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.image_embeds.shape,
|
||||||
|
(
|
||||||
|
self.vision_model_tester.batch_size,
|
||||||
|
config.vision_config.hidden_size,
|
||||||
|
config.image_text_hidden_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
result2 = model(
|
||||||
|
pixel_values,
|
||||||
|
return_dict=not config.use_return_dict,
|
||||||
|
output_attentions=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parent.assertTrue(torch.allclose(result.image_embeds, result2[0]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0]))
|
||||||
|
self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1]))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class Blip2VisionModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (Blip2VisionModelWithProjection,) if is_torch_available() else ()
|
||||||
|
fx_compatible = False
|
||||||
|
test_pruning = False
|
||||||
|
test_head_masking = False
|
||||||
|
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Blip2VisionModelWithProjectionTester(self)
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Training is not yet supported")
|
||||||
|
def test_training(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Training is not yet supported")
|
||||||
|
def test_training_gradient_checkpointing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Training is not yet supported")
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Training is not yet supported")
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2VisionModelWithProjection does not use inputs_embeds")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2VisionModelWithProjection does not support input and output embeddings")
|
||||||
|
def test_model_get_set_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
||||||
|
x = model.get_output_embeddings()
|
||||||
|
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||||
|
def test_save_load_fast_init_from_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||||
|
def test_save_load_fast_init_to_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["pixel_values"]
|
||||||
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
def test_model_from_pretrained(self):
    # Load the pretrained checkpoint and check the shape of the returned `image_embeds`.
    checkpoint = "Salesforce/blip2-itm-vit-g"
    model = Blip2VisionModelWithProjection.from_pretrained(checkpoint)
    self.assertIsNotNone(model)
    self.assertTrue(hasattr(model, "vision_projection"))

    _, pixel_values = self.model_tester.prepare_config_and_inputs()

    model.to(torch_device)
    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    expected_shape = (
        self.model_tester.vision_model_tester.batch_size,
        model.config.num_query_tokens,
        model.config.image_text_hidden_size,
    )
    self.assertEqual(outputs.image_embeds.shape, expected_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class Blip2TextRetrievalModelTester:
    """Prepare configs and inputs for `Blip2ForImageTextRetrieval` from the vision and Q-Former testers."""

    def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
        # When no Q-Former kwargs are given, the default enables `use_qformer_text_input`.
        vision_kwargs = {} if vision_kwargs is None else vision_kwargs
        qformer_kwargs = {"use_qformer_text_input": True} if qformer_kwargs is None else qformer_kwargs

        self.parent = parent
        self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
        self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
        self.is_training = is_training
        # need bs for batching_equivalence test
        self.batch_size = self.vision_model_tester.batch_size

    def get_config(self):
        """Combine the vision and Q-Former sub-configs into a single `Blip2Config`."""
        vision_config = self.vision_model_tester.get_config()
        qformer_config = self.qformer_model_tester.get_config()
        return Blip2Config.from_vision_qformer_text_configs(
            vision_config=vision_config,
            qformer_config=qformer_config,
        )

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids, attention_mask, pixel_values)`."""
        # The sub-testers also return their own configs; those are discarded here.
        _, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs()
        _, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        return self.get_config(), input_ids, attention_mask, pixel_values

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        """Run the model with and without the image-text matching head and check the logits shapes."""
        model = Blip2ForImageTextRetrieval(config).to(torch_device).eval()
        vision_bs = self.vision_model_tester.batch_size

        with torch.no_grad():
            head_result = model(pixel_values, input_ids, attention_mask, use_image_text_matching_head=True)
        self.parent.assertEqual(head_result.logits_per_image.shape, (vision_bs, 2))

        with torch.no_grad():
            default_result = model(pixel_values, input_ids, attention_mask)
        self.parent.assertEqual(
            default_result.logits_per_image.shape,
            (self.vision_model_tester.batch_size, self.qformer_model_tester.batch_size),
        )
        self.parent.assertEqual(
            default_result.logits_per_text.shape,
            (self.qformer_model_tester.batch_size, self.vision_model_tester.batch_size),
        )

    def prepare_config_and_inputs_for_common(self):
        """Return the config plus the model inputs packed into a keyword dict."""
        config, input_ids, attention_mask, pixel_values = self.prepare_config_and_inputs()
        return config, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
    """Common model tests for `Blip2ForImageTextRetrieval`."""

    all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else ()
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = Blip2TextRetrievalModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Blip2ForImageTextRetrieval does not support input and output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    # The reason names the class under test (it previously named `Blip2Model`).
    @unittest.skip(reason="Blip2ForImageTextRetrieval does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values", "input_ids", "attention_mask"]
            # `use_image_text_matching_head` is only expected (in fourth position)
            # when the forward signature declares it at all.
            if "use_image_text_matching_head" in arg_names:
                expected_arg_names.append("use_image_text_matching_head")
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_load_vision_qformer_text_config(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        # Save Blip2Config and check if we can load Blip2VisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = Blip2VisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save Blip2Config and check if we can load Blip2QFormerConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            qformer_config = Blip2QFormerConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict())

    @slow
    @require_torch_gpu
    def test_model_from_pretrained(self):
        model_name = "Salesforce/blip2-itm-vit-g"
        model = Blip2ForImageTextRetrieval.from_pretrained(model_name)
        self.assertIsNotNone(model)

        _, input_ids, attention_mask, pixel_values = self.model_tester.prepare_config_and_inputs()

        model.to(torch_device)
        model.eval()

        vision_batch_size = self.model_tester.vision_model_tester.batch_size
        text_batch_size = self.model_tester.qformer_model_tester.batch_size

        with torch.no_grad():
            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_image_text_matching_head=True,
            )
        # With the matching head the first dimension follows the image batch,
        # the same shape the model tester's `create_and_check_model` asserts.
        self.assertEqual(outputs.logits_per_image.shape, (vision_batch_size, 2))

        with torch.no_grad():
            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
        self.assertEqual(outputs.logits_per_image.shape, (vision_batch_size, text_batch_size))

    @unittest.skip(reason="Training is not yet supported")
    def test_training(self):
        pass

    @unittest.skip(reason="Training is not yet supported")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="Training is not yet supported")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="Training is not yet supported")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    def test_initialization(self):
        # Only the config is needed here; the prepared inputs are discarded.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                failure_msg = f"Parameter {name} of model {model_class} seems not properly initialized"
                # check if `logit_scale` is initialized as per the original implementation
                if name == "logit_scale":
                    self.assertAlmostEqual(
                        param.data.item(),
                        np.log(1 / 0.07),
                        delta=1e-3,
                        msg=failure_msg,
                    )
                elif name == "temp":
                    self.assertAlmostEqual(
                        param.data.item(),
                        0.07,
                        delta=1e-3,
                        msg=failure_msg,
                    )
                else:
                    # The parameter mean, rounded to 9 decimals, must be exactly 0.0 or 1.0.
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=failure_msg,
                    )
|
||||||
|
|
||||||
|
|
||||||
# We will verify our results on an image of cute cats
|
# We will verify our results on an image of cute cats
|
||||||
def prepare_img():
|
def prepare_img():
|
||||||
url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
|
url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
|
||||||
@@ -984,7 +1549,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
|||||||
prompt = "Question: which city is this? Answer:"
|
prompt = "Question: which city is this? Answer:"
|
||||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
|
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
|
||||||
|
|
||||||
predictions = model.generate(**inputs)
|
predictions = model.generate(**inputs, max_new_tokens=11)
|
||||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||||
|
|
||||||
# Test output
|
# Test output
|
||||||
@@ -1063,3 +1628,93 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
|||||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
||||||
|
|
||||||
self.assertTrue(generated_text_expanded == generated_text)
|
self.assertTrue(generated_text_expanded == generated_text)
|
||||||
|
|
||||||
|
@require_torch_gpu
def test_inference_itm(self):
    # Compare the model outputs, with and without the image-text matching
    # head, against reference values for the demo image and caption.
    model_name = "Salesforce/blip2-itm-vit-g"
    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)

    image = prepare_img()
    text = "A woman and her dog sitting in a beach"
    inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device)

    # forward pass; gradients are not needed for an inference check
    with torch.no_grad():
        out_itm = model(**inputs, use_image_text_matching_head=True)
        out = model(**inputs)

    # verify
    expected_scores = torch.Tensor([[0.0238, 0.9762]])
    # `dim=1` is what `nn.Softmax()` picked implicitly for a 2D input; the
    # implicit choice is deprecated and emits a warning, so spell it out.
    itm_probs = out_itm[0].cpu().softmax(dim=1)
    self.assertTrue(torch.allclose(itm_probs, expected_scores, rtol=1e-3, atol=1e-3))
    self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
|
@require_torch_gpu
@require_torch_fp16
def test_inference_itm_fp16(self):
    # Half-precision variant: compare the model outputs, with and without the
    # image-text matching head, against reference values.
    model_name = "Salesforce/blip2-itm-vit-g"
    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForImageTextRetrieval.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)

    image = prepare_img()
    text = "A woman and her dog sitting in a beach"
    inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device, dtype=torch.float16)

    # forward pass; gradients are not needed for an inference check
    with torch.no_grad():
        out_itm = model(**inputs, use_image_text_matching_head=True)
        out = model(**inputs)

    # verify
    expected_scores = torch.Tensor([[0.0239, 0.9761]])
    # Upcast to float32 before comparing. `dim=1` is what `nn.Softmax()` picked
    # implicitly for a 2D input; the implicit choice is deprecated, so spell it out.
    itm_probs = out_itm[0].cpu().float().softmax(dim=1)
    self.assertTrue(torch.allclose(itm_probs, expected_scores, rtol=1e-3, atol=1e-3))
    self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
|
||||||
|
|
||||||
|
@require_torch_gpu
@require_torch_fp16
def test_inference_vision_with_projection_fp16(self):
    # Check the first six values of the first image embedding vector against
    # reference values, running the model in float16.
    checkpoint = "Salesforce/blip2-itm-vit-g"
    processor = Blip2Processor.from_pretrained(checkpoint)
    model = Blip2VisionModelWithProjection.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch_device)

    inputs = processor(images=prepare_img(), return_tensors="pt").to(torch_device, dtype=torch.float16)

    # forward pass
    out = model(**inputs)

    # verify
    expected_image_embeds = [
        -0.093994140625,
        -0.075927734375,
        0.031890869140625,
        0.053009033203125,
        0.0352783203125,
        -0.01190185546875,
    ]
    actual_image_embeds = out.image_embeds[0][0][:6].tolist()
    self.assertTrue(np.allclose(actual_image_embeds, expected_image_embeds, atol=1e-3))
|
||||||
|
|
||||||
|
@require_torch_gpu
@require_torch_fp16
def test_inference_text_with_projection_fp16(self):
    # Check the first six values of the first text embedding vector against
    # reference values, running the model in float16.
    checkpoint = "Salesforce/blip2-itm-vit-g"
    processor = Blip2Processor.from_pretrained(checkpoint)
    model = Blip2TextModelWithProjection.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch_device)

    encoded = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt")
    inputs = encoded.to(torch_device)

    # forward pass
    out = model(**inputs)

    # verify
    expected_text_embeds = [
        -0.1082763671875,
        0.053192138671875,
        -0.02825927734375,
        0.0169830322265625,
        0.08648681640625,
        -0.04656982421875,
    ]
    actual_text_embeds = out.text_embeds[0][0][:6].tolist()
    self.assertTrue(np.allclose(actual_text_embeds, expected_text_embeds, atol=1e-3))
|
||||||
|
|||||||
@@ -279,3 +279,46 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
* 5,
|
* 5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
@require_torch
def test_blip2_model_pt(self):
    # Run the zero-shot image classification pipeline on one image and on a
    # batch of five copies; both must yield the same ranked scores.
    image_classifier = pipeline(task="zero-shot-image-classification", model="Salesforce/blip2-itm-vit-g")
    expected_ranking = [
        {"score": 0.369, "label": "2 cats"},
        {"score": 0.333, "label": "a remote"},
        {"score": 0.297, "label": "a plane"},
    ]

    # This is an image of 2 cats with remotes and no planes
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")

    single_output = image_classifier(
        image,
        candidate_labels=["2 cats", "a plane", "a remote"],
        tokenizer_kwargs={"return_token_type_ids": False},
    )
    self.assertEqual(nested_simplify(single_output), expected_ranking)

    batched_output = image_classifier(
        [image] * 5,
        candidate_labels=["2 cats", "a plane", "a remote"],
        batch_size=2,
        tokenizer_kwargs={"return_token_type_ids": False},
    )
    self.assertEqual(nested_simplify(batched_output), [expected_ranking] * 5)
|
||||||
|
|||||||
@@ -169,6 +169,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
|||||||
"ClapAudioModel",
|
"ClapAudioModel",
|
||||||
"ClapAudioModelWithProjection",
|
"ClapAudioModelWithProjection",
|
||||||
"Blip2ForConditionalGeneration",
|
"Blip2ForConditionalGeneration",
|
||||||
|
"Blip2TextModelWithProjection",
|
||||||
|
"Blip2VisionModelWithProjection",
|
||||||
"Blip2QFormerModel",
|
"Blip2QFormerModel",
|
||||||
"Blip2VisionModel",
|
"Blip2VisionModel",
|
||||||
"ErnieMForInformationExtraction",
|
"ErnieMForInformationExtraction",
|
||||||
|
|||||||
Reference in New Issue
Block a user