Add Blip2 model in VQA pipeline (#25532)
* Add Blip2 model in VQA pipeline * use require_torch_gpu for test_large_model_pt_blip2 * use can_generate in vqa pipeline * test Blip2ForConditionalGeneration using float16 * remove custom can_generate from Blip2ForConditionalGeneration
This commit is contained in:
@@ -839,6 +839,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
|
||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("blip-2", "Blip2ForConditionalGeneration"),
|
||||
("vilt", "ViltForQuestionAnswering"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -124,10 +124,19 @@ class VisualQuestionAnsweringPipeline(Pipeline):
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
if self.model.can_generate():
|
||||
model_outputs = self.model.generate(**model_inputs)
|
||||
else:
|
||||
model_outputs = self.model(**model_inputs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs, top_k=5):
|
||||
if self.model.can_generate():
|
||||
return [
|
||||
{"answer": self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()}
|
||||
for output_ids in model_outputs
|
||||
]
|
||||
else:
|
||||
if top_k > self.model.config.num_labels:
|
||||
top_k = self.model.config.num_labels
|
||||
|
||||
|
||||
@@ -18,9 +18,11 @@ from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_
|
||||
from transformers.pipelines import pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
)
|
||||
@@ -28,6 +30,10 @@ from transformers.testing_utils import (
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
else:
|
||||
@@ -84,6 +90,37 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
||||
outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
def test_small_model_pt_blip2(self):
|
||||
vqa_pipeline = pipeline(
|
||||
"visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
|
||||
)
|
||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||
question = "How many cats are there?"
|
||||
|
||||
outputs = vqa_pipeline(image=image, question=question)
|
||||
self.assertEqual(outputs, [{"answer": ANY(str)}])
|
||||
|
||||
outputs = vqa_pipeline({"image": image, "question": question})
|
||||
self.assertEqual(outputs, [{"answer": ANY(str)}])
|
||||
|
||||
outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}])
|
||||
self.assertEqual(outputs, [[{"answer": ANY(str)}]] * 2)
|
||||
|
||||
vqa_pipeline = pipeline(
|
||||
"visual-question-answering",
|
||||
model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
|
||||
model_kwargs={"torch_dtype": torch.float16},
|
||||
device=0,
|
||||
)
|
||||
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
|
||||
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
|
||||
self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)
|
||||
|
||||
outputs = vqa_pipeline(image=image, question=question)
|
||||
self.assertEqual(outputs, [{"answer": ANY(str)}])
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_large_model_pt(self):
|
||||
@@ -109,6 +146,31 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
||||
[[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
def test_large_model_pt_blip2(self):
|
||||
vqa_pipeline = pipeline(
|
||||
"visual-question-answering",
|
||||
model="Salesforce/blip2-opt-2.7b",
|
||||
model_kwargs={"torch_dtype": torch.float16},
|
||||
device=0,
|
||||
)
|
||||
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
|
||||
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
|
||||
|
||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||
question = "Question: how many cats are there? Answer:"
|
||||
|
||||
outputs = vqa_pipeline(image=image, question=question)
|
||||
self.assertEqual(outputs, [{"answer": "two"}])
|
||||
|
||||
outputs = vqa_pipeline({"image": image, "question": question})
|
||||
self.assertEqual(outputs, [{"answer": "two"}])
|
||||
|
||||
outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}])
|
||||
self.assertEqual(outputs, [[{"answer": "two"}]] * 2)
|
||||
|
||||
@require_tf
|
||||
@unittest.skip("Visual question answering not implemented in TF")
|
||||
def test_small_model_tf(self):
|
||||
|
||||
Reference in New Issue
Block a user