BlipModel: get_multimodal_features method (#30438)
* add_blip_get_multimodal_features * Fix docstring error * reimplement get_multimodal_features * fix error * recheck code quality * add new necessary tests
This commit is contained in:
@@ -814,6 +814,59 @@ class BlipModel(BlipPreTrainedModel):
|
|||||||
|
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
def get_multimodal_features(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
    r"""
    Returns:
        multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The multimodal embeddings
        obtained by applying the image embeddings to the text encoder using the cross-attention mechanism.

    Examples:
    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, BlipModel

    >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
    >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> texts = ["a photo of a cat", "a photo of a dog"]
    >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")

    >>> multimodal_features = model.get_multimodal_features(**inputs)
    ```"""
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Only the first output of the vision encoder is consumed below, so the
    # per-layer attentions and hidden states are explicitly not requested;
    # asking for them would only cost extra memory.
    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=return_dict,
    )

    image_embeds = vision_outputs[0]
    # All-ones mask over every leading dimension of the image embeddings
    # (every image position is attended to). It is created directly on the
    # embeddings' device so no cross-device transfer is needed later.
    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

    # The text model receives the image embeddings as encoder hidden states.
    text_outputs = self.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=return_dict,
    )

    pooled_output = text_outputs[1]  # pooled_output
    multimodal_features = self.text_projection(pooled_output)

    return multimodal_features
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
|
@replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -582,6 +582,63 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
model = BlipModel.from_pretrained(model_name)
|
model = BlipModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_get_image_features(self):
    """Check that `get_image_features` returns a `(batch_size, projection_dim)` tensor."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Strip the text-side and loss entries so only the remaining inputs are forwarded.
    for unused_key in ("input_ids", "attention_mask", "return_loss"):
        inputs_dict.pop(unused_key)

    model = BlipModel(config).to(torch_device)
    model.eval()

    image_features = model.get_image_features(**inputs_dict)
    expected_shape = (self.model_tester.batch_size, model.projection_dim)
    self.assertEqual(image_features.shape, expected_shape)
|
||||||
|
|
||||||
|
def test_get_text_features(self):
    """Check that `get_text_features` returns a `(batch_size, projection_dim)` tensor."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Strip the image and loss entries so only the remaining inputs are forwarded.
    for unused_key in ("pixel_values", "return_loss"):
        inputs_dict.pop(unused_key)

    model = BlipModel(config).to(torch_device)
    model.eval()

    text_features = model.get_text_features(**inputs_dict)
    expected_shape = (self.model_tester.batch_size, model.projection_dim)
    self.assertEqual(text_features.shape, expected_shape)
|
||||||
|
|
||||||
|
def test_get_multimodal_features(self):
    """Check that `get_multimodal_features` returns a `(batch_size, projection_dim)` tensor."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Both text and image inputs are kept; only the loss flag is removed.
    for unused_key in ("return_loss",):
        inputs_dict.pop(unused_key)

    model = BlipModel(config).to(torch_device)
    model.eval()

    multimodal_features = model.get_multimodal_features(**inputs_dict)
    expected_shape = (self.model_tester.batch_size, model.projection_dim)
    self.assertEqual(multimodal_features.shape, expected_shape)
|
||||||
|
|
||||||
def test_pt_tf_model_equivalence(self):
|
def test_pt_tf_model_equivalence(self):
|
||||||
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
|
super().test_pt_tf_model_equivalence(allow_missing_keys=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user