[Blip2] Add Blip2Model (#21817)
* add v1
* add `Blip2Model`
  - add relevant functions
  - add tests
  - add to auto-mapping
* fix docs
* fix doctest
This commit is contained in:
@@ -40,7 +40,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import Blip2ForConditionalGeneration, Blip2VisionModel
|
||||
from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel
|
||||
from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -664,8 +664,8 @@ class Blip2ForConditionalGenerationModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
|
||||
class Blip2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
@@ -737,6 +737,56 @@ class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_get_text_features(self):
    """Check the output shape of `Blip2Model.get_text_features`.

    The inputs returned by the model tester are discarded; a single hand-written
    sequence of ten token ids (with an all-ones attention mask) is used instead.
    The first element of the result must have shape
    `(1, 10, config.text_config.vocab_size)`.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # One batch entry of ten tokens, reused for encoder and decoder ids.
    token_ids = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
    full_mask = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
    text_inputs = {
        "input_ids": torch.LongTensor(token_ids).to(torch_device),
        "attention_mask": torch.LongTensor(full_mask).to(torch_device),
        "decoder_input_ids": torch.LongTensor(token_ids).to(torch_device),
    }

    blip2 = Blip2Model(config).to(torch_device)
    blip2.eval()
    outputs = blip2.get_text_features(**text_inputs)

    expected_shape = (1, 10, config.text_config.vocab_size)
    self.assertEqual(outputs[0].shape, expected_shape)
|
||||
|
||||
def test_get_image_features(self):
    """Check the output shape of `Blip2Model.get_image_features`.

    Starts from the common tester inputs, removes the listed keys, and forwards
    whatever remains. The first element of the result must have shape
    `(batch_size, seq_length, config.vision_config.hidden_size)`, where the first
    two values are read from the vision model tester.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # These entries must not be forwarded to `get_image_features`; `pop` without
    # a default raises KeyError if the tester ever stops providing one of them.
    for unwanted_key in ("input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"):
        inputs_dict.pop(unwanted_key)

    blip2 = Blip2Model(config).to(torch_device)
    blip2.eval()
    outputs = blip2.get_image_features(**inputs_dict)

    vision_tester = self.model_tester.vision_model_tester
    expected_shape = (
        vision_tester.batch_size,
        vision_tester.seq_length,
        config.vision_config.hidden_size,
    )
    self.assertEqual(outputs[0].shape, expected_shape)
|
||||
|
||||
def test_get_qformer_features(self):
    """Check the output shape of `Blip2Model.get_qformer_features`.

    Starts from the common tester inputs, removes the listed keys, and forwards
    whatever remains. The first element of the result must have shape
    `(batch_size, 10, config.vision_config.hidden_size)`, with `batch_size` read
    from the vision model tester.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # These entries must not be forwarded to `get_qformer_features`; `pop` without
    # a default raises KeyError if the tester ever stops providing one of them.
    for unwanted_key in ("input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"):
        inputs_dict.pop(unwanted_key)

    blip2 = Blip2Model(config).to(torch_device)
    blip2.eval()
    outputs = blip2.get_qformer_features(**inputs_dict)

    # NOTE(review): the literal 10 presumably mirrors the number of query tokens
    # configured by the model tester -- confirm against the tester's settings.
    expected_shape = (
        self.model_tester.vision_model_tester.batch_size,
        10,
        config.vision_config.hidden_size,
    )
    self.assertEqual(outputs[0].shape, expected_shape)
|
||||
|
||||
# override from common to deal with nested configurations (`vision_config`, `text_config` and `qformer_config`)
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user