Add Blip2 model in VQA pipeline (#25532)
* Add Blip2 model in VQA pipeline * use require_torch_gpu for test_large_model_pt_blip2 * use can_generate in vqa pipeline * test Blip2ForConditionalGeneration using float16 * remove custom can_generate from Blip2ForConditionalGeneration
This commit is contained in:
@@ -839,6 +839,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
|
|
||||||
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
|
("blip-2", "Blip2ForConditionalGeneration"),
|
||||||
("vilt", "ViltForQuestionAnswering"),
|
("vilt", "ViltForQuestionAnswering"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -124,19 +124,28 @@ class VisualQuestionAnsweringPipeline(Pipeline):
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def _forward(self, model_inputs):
|
def _forward(self, model_inputs):
|
||||||
model_outputs = self.model(**model_inputs)
|
if self.model.can_generate():
|
||||||
|
model_outputs = self.model.generate(**model_inputs)
|
||||||
|
else:
|
||||||
|
model_outputs = self.model(**model_inputs)
|
||||||
return model_outputs
|
return model_outputs
|
||||||
|
|
||||||
def postprocess(self, model_outputs, top_k=5):
|
def postprocess(self, model_outputs, top_k=5):
|
||||||
if top_k > self.model.config.num_labels:
|
if self.model.can_generate():
|
||||||
top_k = self.model.config.num_labels
|
return [
|
||||||
|
{"answer": self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()}
|
||||||
if self.framework == "pt":
|
for output_ids in model_outputs
|
||||||
probs = model_outputs.logits.sigmoid()[0]
|
]
|
||||||
scores, ids = probs.topk(top_k)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported framework: {self.framework}")
|
if top_k > self.model.config.num_labels:
|
||||||
|
top_k = self.model.config.num_labels
|
||||||
|
|
||||||
scores = scores.tolist()
|
if self.framework == "pt":
|
||||||
ids = ids.tolist()
|
probs = model_outputs.logits.sigmoid()[0]
|
||||||
return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
scores, ids = probs.topk(top_k)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported framework: {self.framework}")
|
||||||
|
|
||||||
|
scores = scores.tolist()
|
||||||
|
ids = ids.tolist()
|
||||||
|
return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
||||||
|
|||||||
@@ -18,9 +18,11 @@ from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_
|
|||||||
from transformers.pipelines import pipeline
|
from transformers.pipelines import pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
|
is_torch_available,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
require_vision,
|
require_vision,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@@ -28,6 +30,10 @@ from transformers.testing_utils import (
|
|||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
else:
|
else:
|
||||||
@@ -84,6 +90,37 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
|||||||
outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
|
outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_small_model_pt_blip2(self):
|
||||||
|
vqa_pipeline = pipeline(
|
||||||
|
"visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
|
||||||
|
)
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
question = "How many cats are there?"
|
||||||
|
|
||||||
|
outputs = vqa_pipeline(image=image, question=question)
|
||||||
|
self.assertEqual(outputs, [{"answer": ANY(str)}])
|
||||||
|
|
||||||
|
outputs = vqa_pipeline({"image": image, "question": question})
|
||||||
|
self.assertEqual(outputs, [{"answer": ANY(str)}])
|
||||||
|
|
||||||
|
outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}])
|
||||||
|
self.assertEqual(outputs, [[{"answer": ANY(str)}]] * 2)
|
||||||
|
|
||||||
|
vqa_pipeline = pipeline(
|
||||||
|
"visual-question-answering",
|
||||||
|
model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
|
||||||
|
model_kwargs={"torch_dtype": torch.float16},
|
||||||
|
device=0,
|
||||||
|
)
|
||||||
|
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
|
||||||
|
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
|
||||||
|
self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)
|
||||||
|
|
||||||
|
outputs = vqa_pipeline(image=image, question=question)
|
||||||
|
self.assertEqual(outputs, [{"answer": ANY(str)}])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_large_model_pt(self):
|
def test_large_model_pt(self):
|
||||||
@@ -109,6 +146,31 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
|||||||
[[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2,
|
[[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_large_model_pt_blip2(self):
|
||||||
|
vqa_pipeline = pipeline(
|
||||||
|
"visual-question-answering",
|
||||||
|
model="Salesforce/blip2-opt-2.7b",
|
||||||
|
model_kwargs={"torch_dtype": torch.float16},
|
||||||
|
device=0,
|
||||||
|
)
|
||||||
|
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
|
||||||
|
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
|
||||||
|
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
question = "Question: how many cats are there? Answer:"
|
||||||
|
|
||||||
|
outputs = vqa_pipeline(image=image, question=question)
|
||||||
|
self.assertEqual(outputs, [{"answer": "two"}])
|
||||||
|
|
||||||
|
outputs = vqa_pipeline({"image": image, "question": question})
|
||||||
|
self.assertEqual(outputs, [{"answer": "two"}])
|
||||||
|
|
||||||
|
outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}])
|
||||||
|
self.assertEqual(outputs, [[{"answer": "two"}]] * 2)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@unittest.skip("Visual question answering not implemented in TF")
|
@unittest.skip("Visual question answering not implemented in TF")
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user