[Blip2] Add Blip2Model (#21817)

* add v1

* add `Blip2Model`

- add relevant functions
- add tests
- add to auto mapping

* fix docs

* fix doctest
This commit is contained in:
Younes Belkada
2023-02-28 15:42:55 +01:00
committed by GitHub
parent ae9230af40
commit b8de7e448e
7 changed files with 441 additions and 3 deletions

View File

@@ -40,7 +40,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import Blip2ForConditionalGeneration, Blip2VisionModel
from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel
from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -664,8 +664,8 @@ class Blip2ForConditionalGenerationModelTester:
@require_torch
class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
class Blip2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
fx_compatible = False
test_head_masking = False
test_pruning = False
@@ -737,6 +737,56 @@ class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
model = Blip2ForConditionalGeneration.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_get_text_features(self):
    """Call `Blip2Model.get_text_features` on a fixed 1x10 batch and check the first output's shape."""
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # One sequence of ten token ids, reused for both the encoder and decoder inputs.
    token_ids = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
    full_mask = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
    text_inputs = {
        "input_ids": torch.LongTensor(token_ids).to(torch_device),
        "attention_mask": torch.LongTensor(full_mask).to(torch_device),
        "decoder_input_ids": torch.LongTensor(token_ids).to(torch_device),
    }

    blip2 = Blip2Model(config).to(torch_device)
    blip2.eval()

    outputs = blip2.get_text_features(**text_inputs)

    # (batch, sequence length, text vocabulary size)
    expected_shape = (1, 10, config.text_config.vocab_size)
    self.assertEqual(outputs[0].shape, expected_shape)
def test_get_image_features(self):
    """Call `Blip2Model.get_image_features` on the common inputs minus the text keys and check the output shape."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Strip the text/decoder entries; whatever remains is forwarded to `get_image_features`.
    text_only_keys = ("input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels")
    for text_key in text_only_keys:
        inputs_dict.pop(text_key)

    blip2 = Blip2Model(config).to(torch_device)
    blip2.eval()

    outputs = blip2.get_image_features(**inputs_dict)

    vision_tester = self.model_tester.vision_model_tester
    expected_shape = (
        vision_tester.batch_size,
        vision_tester.seq_length,
        config.vision_config.hidden_size,
    )
    self.assertEqual(outputs[0].shape, expected_shape)
def test_get_qformer_features(self):
    """Call `Blip2Model.get_qformer_features` on the common inputs minus the text keys and check the output shape.

    The expected shape is derived from the configuration rather than hard-coded, so the
    assertion stays correct if the tester changes the number of query tokens or gives the
    Q-Former a hidden size different from the vision encoder's.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Strip the text/decoder entries; whatever remains is forwarded to `get_qformer_features`.
    keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
    for key in keys_to_pop:
        inputs_dict.pop(key)

    model = Blip2Model(config).to(torch_device)
    model.eval()

    qformer_features = model.get_qformer_features(**inputs_dict)

    # NOTE(review): per the Blip2Config documentation, the Q-Former output has one vector per
    # learned query token (`config.num_query_tokens`), each of width
    # `config.qformer_config.hidden_size` — confirm against the model implementation.
    expected_shape = (
        self.model_tester.vision_model_tester.batch_size,
        config.num_query_tokens,
        config.qformer_config.hidden_size,
    )
    self.assertEqual(qformer_features[0].shape, expected_shape)
# override from common to deal with nested configurations (`vision_config`, `text_config` and `qformer_config`)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()