🚨 Add Blip2ForImageTextRetrieval (#29261)

* add Blip2ForImageTextRetrieval

* use one line and remove unnecessary space in tests

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* use  value from the config, rather than hardcoded

* change order of params in Blip2QFormerModel.forward

* update docstring

* fix style

* update test_inference_opt

* move embeddings out of Blip2QFormerModel

* remove from_vision_qformer_configs

* remove autocast float16 in Blip2QFormerModel

* rename fiels into vision_projection,text_projection,use_image_text_matching_head

* use CLIPOutput for  Blip2ImageTextMatchingModelOutput

* remove past_key_values_length from Blip2TextEmbeddings

* fix small typo in the CLIPOutput docstring

* add Blip2ForImageTextRetrieval to Zero Shot Image Classification mapping

* update docstring and add require_torch_fp16

* rollback test_inference_opt

* use use_image_text_matching_head=True in convert

* skip test_model_get_set_embeddings

* fix create_rename_keys error on new itm fields

* revert to do  scale after dot product between "query" and "key"

* fix ValueError on convert script for blip2-opt-2.7b

* update org of paths to Salesforce

* add is_pipeline_test_to_skip for VisualQuestionAnsweringPipelineTests

* [run_slow] blip_2

* removed Blip2ForImageTextRetrieval from IGNORE_NON_AUTO_CONFIGURED

* fix docstring of Blip2ImageTextMatchingModelOutput

* [run_slow] blip_2

* fix multi-gpu tests

* [run_slow] blip_2

* [run_slow] blip_2

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Juan Pizarro
2024-08-27 19:50:27 +02:00
committed by GitHub
parent 27903de7ec
commit 7591ca5bc5
17 changed files with 1568 additions and 101 deletions

View File

@@ -24,6 +24,8 @@ import requests
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
from transformers.testing_utils import (
require_torch,
require_torch_fp16,
require_torch_gpu,
require_torch_multi_accelerator,
require_vision,
slow,
@@ -47,7 +49,14 @@ if is_torch_available():
import torch
from torch import nn
from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel
from transformers import (
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
Blip2Model,
Blip2TextModelWithProjection,
Blip2VisionModel,
Blip2VisionModelWithProjection,
)
if is_vision_available():
@@ -243,6 +252,7 @@ class Blip2QFormerModelTester:
initializer_range=0.02,
bos_token_id=0,
scope=None,
use_qformer_text_input=False,
):
self.parent = parent
self.batch_size = batch_size
@@ -262,6 +272,7 @@ class Blip2QFormerModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.bos_token_id = bos_token_id
self.use_qformer_text_input = use_qformer_text_input
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -294,6 +305,7 @@ class Blip2QFormerModelTester:
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
bos_token_id=self.bos_token_id,
use_qformer_text_input=self.use_qformer_text_input,
)
@@ -489,7 +501,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_load_vision_qformer_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# Save Blip2Config and check if we can load Blip2VisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
@@ -704,6 +716,16 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
test_attention_outputs = False
test_torchscript = False
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
):
if pipeline_test_casse_name == "VisualQuestionAnsweringPipelineTests":
# Get `RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'`.
return True
return False
def setUp(self):
self.model_tester = Blip2ModelTester(self)
@@ -752,7 +774,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_load_vision_qformer_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# Save Blip2Config and check if we can load Blip2VisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
@@ -840,6 +862,549 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
)
class Blip2TextModelWithProjectionTester:
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
if vision_kwargs is None:
vision_kwargs = {}
if qformer_kwargs is None:
qformer_kwargs = {"use_qformer_text_input": True}
self.parent = parent
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
self.is_training = is_training
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
def get_config(self):
return Blip2Config.from_vision_qformer_text_configs(
vision_config=self.vision_model_tester.get_config(),
qformer_config=self.qformer_model_tester.get_config(),
)
def prepare_config_and_inputs(self):
_, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
def create_and_check_model(self, config, input_ids, attention_mask):
model = Blip2TextModelWithProjection(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.vision_model_tester.batch_size, input_ids.shape[1], self.qformer_model_tester.hidden_size),
)
self.parent.assertEqual(
result.text_embeds.shape,
(
self.vision_model_tester.batch_size,
input_ids.shape[1],
config.image_text_hidden_size,
),
)
with torch.no_grad():
result2 = model(
input_ids,
attention_mask=attention_mask,
return_dict=not config.use_return_dict,
output_attentions=True,
output_hidden_states=True,
)
self.parent.assertTrue(torch.allclose(result.text_embeds, result2[0]))
self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1]))
self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0]))
self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1]))
self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0]))
self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1]))
@require_torch
class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2TextModelWithProjection,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_head_masking = False
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = False
def setUp(self):
self.model_tester = Blip2TextModelWithProjectionTester(self)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Training is not yet supported")
def test_training(self):
pass
@unittest.skip(reason="Training is not yet supported")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Blip2TextModelWithProjection does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Blip2TextModelWithProjection does not support input and output embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Blip2TextModelWithProjection does not have input/output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_ids", "attention_mask", "position_ids"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@slow
@require_torch_gpu
def test_model_from_pretrained(self):
model_name = "Salesforce/blip2-itm-vit-g"
model = Blip2TextModelWithProjection.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "text_projection"))
_, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
self.assertEqual(
outputs.text_embeds.shape,
(
self.model_tester.qformer_model_tester.batch_size,
input_ids.shape[1],
model.config.image_text_hidden_size,
),
)
class Blip2VisionModelWithProjectionTester:
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
if vision_kwargs is None:
vision_kwargs = {}
if qformer_kwargs is None:
qformer_kwargs = {"use_qformer_text_input": True}
self.parent = parent
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
self.is_training = is_training
self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
self.num_attention_heads = self.vision_model_tester.num_attention_heads
self.seq_length = self.vision_model_tester.seq_length
self.hidden_size = self.vision_model_tester.hidden_size
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
def get_config(self):
return Blip2Config.from_vision_qformer_text_configs(
vision_config=self.vision_model_tester.get_config(),
qformer_config=self.qformer_model_tester.get_config(),
)
def prepare_config_and_inputs(self):
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
def create_and_check_model(self, config, pixel_values):
model = Blip2VisionModelWithProjection(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values, output_attentions=True, output_hidden_states=True)
self.parent.assertEqual(
result.last_hidden_state.shape,
(
self.vision_model_tester.batch_size,
self.vision_model_tester.seq_length,
self.qformer_model_tester.hidden_size,
),
)
self.parent.assertEqual(
result.image_embeds.shape,
(
self.vision_model_tester.batch_size,
config.vision_config.hidden_size,
config.image_text_hidden_size,
),
)
with torch.no_grad():
result2 = model(
pixel_values,
return_dict=not config.use_return_dict,
output_attentions=True,
output_hidden_states=True,
)
self.parent.assertTrue(torch.allclose(result.image_embeds, result2[0]))
self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1]))
self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0]))
self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1]))
self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0]))
self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1]))
@require_torch
class Blip2VisionModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2VisionModelWithProjection,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_head_masking = False
test_resize_embeddings = False
test_torchscript = False
def setUp(self):
self.model_tester = Blip2VisionModelWithProjectionTester(self)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Training is not yet supported")
def test_training(self):
pass
@unittest.skip(reason="Training is not yet supported")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="Training is not yet supported")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="Training is not yet supported")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Blip2VisionModelWithProjection does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Blip2VisionModelWithProjection does not support input and output embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
@unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@slow
@require_torch_gpu
def test_model_from_pretrained(self):
model_name = "Salesforce/blip2-itm-vit-g"
model = Blip2VisionModelWithProjection.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "vision_projection"))
_, pixel_values = self.model_tester.prepare_config_and_inputs()
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(pixel_values=pixel_values)
self.assertEqual(
outputs.image_embeds.shape,
(
self.model_tester.vision_model_tester.batch_size,
model.config.num_query_tokens,
model.config.image_text_hidden_size,
),
)
class Blip2TextRetrievalModelTester:
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
if vision_kwargs is None:
vision_kwargs = {}
if qformer_kwargs is None:
qformer_kwargs = {"use_qformer_text_input": True}
self.parent = parent
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
self.is_training = is_training
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
def get_config(self):
return Blip2Config.from_vision_qformer_text_configs(
vision_config=self.vision_model_tester.get_config(),
qformer_config=self.qformer_model_tester.get_config(),
)
def prepare_config_and_inputs(self):
_, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs()
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = Blip2ForImageTextRetrieval(config).to(torch_device).eval()
with torch.no_grad():
result = model(pixel_values, input_ids, attention_mask, use_image_text_matching_head=True)
self.parent.assertEqual(
result.logits_per_image.shape,
(self.vision_model_tester.batch_size, 2),
)
with torch.no_grad():
result = model(pixel_values, input_ids, attention_mask)
self.parent.assertEqual(
result.logits_per_image.shape,
(self.vision_model_tester.batch_size, self.qformer_model_tester.batch_size),
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.qformer_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
return config, inputs_dict
@require_torch
class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else ()
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = False
def setUp(self):
self.model_tester = Blip2TextRetrievalModelTester(self)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Blip2ForImageTextRetrieval does not support input and output embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="Blip2Model does not have input/output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values", "input_ids", "attention_mask"]
expected_arg_names.extend(
["use_image_text_matching_head"] if "use_image_text_matching_head" in arg_names else []
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_load_vision_qformer_text_config(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# Save Blip2Config and check if we can load Blip2VisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = Blip2VisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save Blip2Config and check if we can load Blip2QFormerConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
qformer_config = Blip2QFormerConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict())
@slow
@require_torch_gpu
def test_model_from_pretrained(self):
model_name = "Salesforce/blip2-itm-vit-g"
model = Blip2ForImageTextRetrieval.from_pretrained(model_name)
self.assertIsNotNone(model)
_, input_ids, attention_mask, pixel_values = self.model_tester.prepare_config_and_inputs()
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
use_image_text_matching_head=True,
)
self.assertEqual(outputs.logits_per_image.shape, (self.model_tester.qformer_model_tester.batch_size, 2))
with torch.no_grad():
outputs = model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
)
self.assertEqual(
outputs.logits_per_image.shape,
(self.model_tester.vision_model_tester.batch_size, self.model_tester.qformer_model_tester.batch_size),
)
@unittest.skip(reason="Training is not yet supported")
def test_training(self):
pass
@unittest.skip(reason="Training is not yet supported")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="Training is not yet supported")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="Training is not yet supported")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
# check if `logit_scale` is initilized as per the original implementation
if name == "logit_scale":
self.assertAlmostEqual(
param.data.item(),
np.log(1 / 0.07),
delta=1e-3,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif name == "temp":
self.assertAlmostEqual(
param.data.item(),
0.07,
delta=1e-3,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# We will verify our results on an image of cute cats
def prepare_img():
url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
@@ -984,7 +1549,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
prompt = "Question: which city is this? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
predictions = model.generate(**inputs)
predictions = model.generate(**inputs, max_new_tokens=11)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
@@ -1063,3 +1628,93 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
self.assertTrue(generated_text_expanded == generated_text)
@require_torch_gpu
def test_inference_itm(self):
model_name = "Salesforce/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)
image = prepare_img()
text = "A woman and her dog sitting in a beach"
inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device)
# forward pass
out_itm = model(**inputs, use_image_text_matching_head=True)
out = model(**inputs)
# verify
expected_scores = torch.Tensor([[0.0238, 0.9762]])
self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3))
self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
@require_torch_gpu
@require_torch_fp16
def test_inference_itm_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForImageTextRetrieval.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
image = prepare_img()
text = "A woman and her dog sitting in a beach"
inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device, dtype=torch.float16)
# forward pass
out_itm = model(**inputs, use_image_text_matching_head=True)
out = model(**inputs)
# verify
expected_scores = torch.Tensor([[0.0239, 0.9761]])
self.assertTrue(
torch.allclose(torch.nn.Softmax()(out_itm[0].cpu().float()), expected_scores, rtol=1e-3, atol=1e-3)
)
self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
@require_torch_gpu
@require_torch_fp16
def test_inference_vision_with_projection_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2VisionModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16)
# forward pass
out = model(**inputs)
# verify
expected_image_embeds = [
-0.093994140625,
-0.075927734375,
0.031890869140625,
0.053009033203125,
0.0352783203125,
-0.01190185546875,
]
self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3))
@require_torch_gpu
@require_torch_fp16
def test_inference_text_with_projection_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2TextModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to(
torch_device
)
# forward pass
out = model(**inputs)
# verify
expected_text_embeds = [
-0.1082763671875,
0.053192138671875,
-0.02825927734375,
0.0169830322265625,
0.08648681640625,
-0.04656982421875,
]
self.assertTrue(np.allclose(out.text_embeds[0][0][:6].tolist(), expected_text_embeds, atol=1e-3))