🚨 Add Blip2ForImageTextRetrieval (#29261)
* add Blip2ForImageTextRetrieval * use one line and remove unnecessary space in tests Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * use value from the config, rather than hardcoded * change order of params in Blip2QFormerModel.forward * update docstring * fix style * update test_inference_opt * move embeddings out of Blip2QFormerModel * remove from_vision_qformer_configs * remove autocast float16 in Blip2QFormerModel * rename fiels into vision_projection,text_projection,use_image_text_matching_head * use CLIPOutput for Blip2ImageTextMatchingModelOutput * remove past_key_values_length from Blip2TextEmbeddings * fix small typo in the CLIPOutput docstring * add Blip2ForImageTextRetrieval to Zero Shot Image Classification mapping * update docstring and add require_torch_fp16 * rollback test_inference_opt * use use_image_text_matching_head=True in convert * skip test_model_get_set_embeddings * fix create_rename_keys error on new itm fields * revert to do scale after dot product between "query" and "key" * fix ValueError on convert script for blip2-opt-2.7b * update org of paths to Salesforce * add is_pipeline_test_to_skip for VisualQuestionAnsweringPipelineTests * [run_slow] blip_2 * removed Blip2ForImageTextRetrieval from IGNORE_NON_AUTO_CONFIGURED * fix docstring of Blip2ImageTextMatchingModelOutput * [run_slow] blip_2 * fix multi-gpu tests * [run_slow] blip_2 * [run_slow] blip_2 --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -24,6 +24,8 @@ import requests
|
||||
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_fp16,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_vision,
|
||||
slow,
|
||||
@@ -47,7 +49,14 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel
|
||||
from transformers import (
|
||||
Blip2ForConditionalGeneration,
|
||||
Blip2ForImageTextRetrieval,
|
||||
Blip2Model,
|
||||
Blip2TextModelWithProjection,
|
||||
Blip2VisionModel,
|
||||
Blip2VisionModelWithProjection,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@@ -243,6 +252,7 @@ class Blip2QFormerModelTester:
|
||||
initializer_range=0.02,
|
||||
bos_token_id=0,
|
||||
scope=None,
|
||||
use_qformer_text_input=False,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -262,6 +272,7 @@ class Blip2QFormerModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.bos_token_id = bos_token_id
|
||||
self.use_qformer_text_input = use_qformer_text_input
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -294,6 +305,7 @@ class Blip2QFormerModelTester:
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
bos_token_id=self.bos_token_id,
|
||||
use_qformer_text_input=self.use_qformer_text_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -489,7 +501,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_load_vision_qformer_text_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save Blip2Config and check if we can load Blip2VisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
@@ -704,6 +716,16 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
||||
test_attention_outputs = False
|
||||
test_torchscript = False
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
):
|
||||
if pipeline_test_casse_name == "VisualQuestionAnsweringPipelineTests":
|
||||
# Get `RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'`.
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Blip2ModelTester(self)
|
||||
|
||||
@@ -752,7 +774,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_load_vision_qformer_text_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save Blip2Config and check if we can load Blip2VisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
@@ -840,6 +862,549 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
||||
)
|
||||
|
||||
|
||||
class Blip2TextModelWithProjectionTester:
|
||||
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
if qformer_kwargs is None:
|
||||
qformer_kwargs = {"use_qformer_text_input": True}
|
||||
|
||||
self.parent = parent
|
||||
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
|
||||
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
|
||||
self.is_training = is_training
|
||||
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
|
||||
|
||||
def get_config(self):
|
||||
return Blip2Config.from_vision_qformer_text_configs(
|
||||
vision_config=self.vision_model_tester.get_config(),
|
||||
qformer_config=self.qformer_model_tester.get_config(),
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
_, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask):
|
||||
model = Blip2TextModelWithProjection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(input_ids, attention_mask=attention_mask, output_attentions=True, output_hidden_states=True)
|
||||
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape,
|
||||
(self.vision_model_tester.batch_size, input_ids.shape[1], self.qformer_model_tester.hidden_size),
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.text_embeds.shape,
|
||||
(
|
||||
self.vision_model_tester.batch_size,
|
||||
input_ids.shape[1],
|
||||
config.image_text_hidden_size,
|
||||
),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
result2 = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=not config.use_return_dict,
|
||||
output_attentions=True,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
self.parent.assertTrue(torch.allclose(result.text_embeds, result2[0]))
|
||||
self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1]))
|
||||
self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0]))
|
||||
self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1]))
|
||||
self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0]))
|
||||
self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1]))
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2TextModelWithProjection,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Blip2TextModelWithProjectionTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2TextModelWithProjection does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2TextModelWithProjection does not support input and output embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2TextModelWithProjection does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
model = Blip2TextModelWithProjection.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertTrue(hasattr(model, "text_projection"))
|
||||
|
||||
_, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
self.assertEqual(
|
||||
outputs.text_embeds.shape,
|
||||
(
|
||||
self.model_tester.qformer_model_tester.batch_size,
|
||||
input_ids.shape[1],
|
||||
model.config.image_text_hidden_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Blip2VisionModelWithProjectionTester:
|
||||
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
if qformer_kwargs is None:
|
||||
qformer_kwargs = {"use_qformer_text_input": True}
|
||||
|
||||
self.parent = parent
|
||||
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
|
||||
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
|
||||
self.is_training = is_training
|
||||
self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
|
||||
self.num_attention_heads = self.vision_model_tester.num_attention_heads
|
||||
self.seq_length = self.vision_model_tester.seq_length
|
||||
self.hidden_size = self.vision_model_tester.hidden_size
|
||||
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
|
||||
|
||||
def get_config(self):
|
||||
return Blip2Config.from_vision_qformer_text_configs(
|
||||
vision_config=self.vision_model_tester.get_config(),
|
||||
qformer_config=self.qformer_model_tester.get_config(),
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = Blip2VisionModelWithProjection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values, output_attentions=True, output_hidden_states=True)
|
||||
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape,
|
||||
(
|
||||
self.vision_model_tester.batch_size,
|
||||
self.vision_model_tester.seq_length,
|
||||
self.qformer_model_tester.hidden_size,
|
||||
),
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.image_embeds.shape,
|
||||
(
|
||||
self.vision_model_tester.batch_size,
|
||||
config.vision_config.hidden_size,
|
||||
config.image_text_hidden_size,
|
||||
),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
result2 = model(
|
||||
pixel_values,
|
||||
return_dict=not config.use_return_dict,
|
||||
output_attentions=True,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
self.parent.assertTrue(torch.allclose(result.image_embeds, result2[0]))
|
||||
self.parent.assertTrue(torch.allclose(result.last_hidden_state, result2[1]))
|
||||
self.parent.assertTrue(torch.allclose(result.hidden_states[0], result2[2][0]))
|
||||
self.parent.assertTrue(torch.allclose(result.hidden_states[1], result2[2][1]))
|
||||
self.parent.assertTrue(torch.allclose(result.attentions[0], result2[3][0]))
|
||||
self.parent.assertTrue(torch.allclose(result.attentions[1], result2[3][1]))
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blip2VisionModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2VisionModelWithProjection,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
test_resize_embeddings = False
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Blip2VisionModelWithProjectionTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2VisionModelWithProjection does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2VisionModelWithProjection does not support input and output embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
@unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
model = Blip2VisionModelWithProjection.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertTrue(hasattr(model, "vision_projection"))
|
||||
|
||||
_, pixel_values = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values=pixel_values)
|
||||
|
||||
self.assertEqual(
|
||||
outputs.image_embeds.shape,
|
||||
(
|
||||
self.model_tester.vision_model_tester.batch_size,
|
||||
model.config.num_query_tokens,
|
||||
model.config.image_text_hidden_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Blip2TextRetrievalModelTester:
|
||||
def __init__(self, parent, vision_kwargs=None, qformer_kwargs=None, is_training=True):
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
if qformer_kwargs is None:
|
||||
qformer_kwargs = {"use_qformer_text_input": True}
|
||||
|
||||
self.parent = parent
|
||||
self.vision_model_tester = Blip2VisionModelTester(parent, **vision_kwargs)
|
||||
self.qformer_model_tester = Blip2QFormerModelTester(parent, **qformer_kwargs)
|
||||
self.is_training = is_training
|
||||
self.batch_size = self.vision_model_tester.batch_size # need bs for batching_equivalence test
|
||||
|
||||
def get_config(self):
|
||||
return Blip2Config.from_vision_qformer_text_configs(
|
||||
vision_config=self.vision_model_tester.get_config(),
|
||||
qformer_config=self.qformer_model_tester.get_config(),
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
_, input_ids, attention_mask = self.qformer_model_tester.prepare_config_and_inputs()
|
||||
_, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = Blip2ForImageTextRetrieval(config).to(torch_device).eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values, input_ids, attention_mask, use_image_text_matching_head=True)
|
||||
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_image.shape,
|
||||
(self.vision_model_tester.batch_size, 2),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values, input_ids, attention_mask)
|
||||
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_image.shape,
|
||||
(self.vision_model_tester.batch_size, self.qformer_model_tester.batch_size),
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_text.shape, (self.qformer_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Blip2TextRetrievalModelTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2ForImageTextRetrieval does not support input and output embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Blip2Model does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values", "input_ids", "attention_mask"]
|
||||
expected_arg_names.extend(
|
||||
["use_image_text_matching_head"] if "use_image_text_matching_head" in arg_names else []
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
def test_load_vision_qformer_text_config(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save Blip2Config and check if we can load Blip2VisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
vision_config = Blip2VisionConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||
|
||||
# Save Blip2Config and check if we can load Blip2QFormerConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
qformer_config = Blip2QFormerConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict())
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
model = Blip2ForImageTextRetrieval.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
_, input_ids, attention_mask, pixel_values = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=True,
|
||||
)
|
||||
self.assertEqual(outputs.logits_per_image.shape, (self.model_tester.qformer_model_tester.batch_size, 2))
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs.logits_per_image.shape,
|
||||
(self.model_tester.vision_model_tester.batch_size, self.model_tester.qformer_model_tester.batch_size),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Training is not yet supported")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
# check if `logit_scale` is initilized as per the original implementation
|
||||
if name == "logit_scale":
|
||||
self.assertAlmostEqual(
|
||||
param.data.item(),
|
||||
np.log(1 / 0.07),
|
||||
delta=1e-3,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
elif name == "temp":
|
||||
self.assertAlmostEqual(
|
||||
param.data.item(),
|
||||
0.07,
|
||||
delta=1e-3,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
else:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
|
||||
@@ -984,7 +1549,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
prompt = "Question: which city is this? Answer:"
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs)
|
||||
predictions = model.generate(**inputs, max_new_tokens=11)
|
||||
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
|
||||
|
||||
# Test output
|
||||
@@ -1063,3 +1628,93 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
generated_text_expanded = processor.batch_decode(predictions_expanded, skip_special_tokens=True)[0].strip()
|
||||
|
||||
self.assertTrue(generated_text_expanded == generated_text)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_inference_itm(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
model = Blip2ForImageTextRetrieval.from_pretrained(model_name).to(torch_device)
|
||||
|
||||
image = prepare_img()
|
||||
text = "A woman and her dog sitting in a beach"
|
||||
inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
out_itm = model(**inputs, use_image_text_matching_head=True)
|
||||
out = model(**inputs)
|
||||
|
||||
# verify
|
||||
expected_scores = torch.Tensor([[0.0238, 0.9762]])
|
||||
self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_fp16
|
||||
def test_inference_itm_fp16(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
model = Blip2ForImageTextRetrieval.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
image = prepare_img()
|
||||
text = "A woman and her dog sitting in a beach"
|
||||
inputs = processor(images=image, text=text, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
# forward pass
|
||||
out_itm = model(**inputs, use_image_text_matching_head=True)
|
||||
out = model(**inputs)
|
||||
|
||||
# verify
|
||||
expected_scores = torch.Tensor([[0.0239, 0.9761]])
|
||||
self.assertTrue(
|
||||
torch.allclose(torch.nn.Softmax()(out_itm[0].cpu().float()), expected_scores, rtol=1e-3, atol=1e-3)
|
||||
)
|
||||
self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_fp16
|
||||
def test_inference_vision_with_projection_fp16(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
model = Blip2VisionModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
image = prepare_img()
|
||||
inputs = processor(images=image, return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
# forward pass
|
||||
out = model(**inputs)
|
||||
|
||||
# verify
|
||||
expected_image_embeds = [
|
||||
-0.093994140625,
|
||||
-0.075927734375,
|
||||
0.031890869140625,
|
||||
0.053009033203125,
|
||||
0.0352783203125,
|
||||
-0.01190185546875,
|
||||
]
|
||||
self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3))
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_fp16
|
||||
def test_inference_text_with_projection_fp16(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
model = Blip2TextModelWithProjection.from_pretrained(model_name, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
inputs = processor(text="a woman sitting on the beach with a dog", padding=True, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# forward pass
|
||||
out = model(**inputs)
|
||||
|
||||
# verify
|
||||
expected_text_embeds = [
|
||||
-0.1082763671875,
|
||||
0.053192138671875,
|
||||
-0.02825927734375,
|
||||
0.0169830322265625,
|
||||
0.08648681640625,
|
||||
-0.04656982421875,
|
||||
]
|
||||
self.assertTrue(np.allclose(out.text_embeds[0][0][:6].tolist(), expected_text_embeds, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user