From 09dc99517f5f38ee210cf1145a7b17fc99b37dac Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Wed, 30 Aug 2023 15:16:16 +0200 Subject: [PATCH] Add Blip2 model in VQA pipeline (#25532) * Add Blip2 model in VQA pipeline * use require_torch_gpu for test_large_model_pt_blip2 * use can_generate in vqa pipeline * test Blip2ForConditionalGeneration using float16 * remove custom can_generate from Blip2ForConditionalGeneration --- src/transformers/models/auto/modeling_auto.py | 1 + .../pipelines/visual_question_answering.py | 31 ++++++---- ...est_pipelines_visual_question_answering.py | 62 +++++++++++++++++++ 3 files changed, 83 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index fa9a483eb3..229251d4ef 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -839,6 +839,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ + ("blip-2", "Blip2ForConditionalGeneration"), ("vilt", "ViltForQuestionAnswering"), ] ) diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py index 339a907cbb..c3bf65114f 100644 --- a/src/transformers/pipelines/visual_question_answering.py +++ b/src/transformers/pipelines/visual_question_answering.py @@ -124,19 +124,28 @@ class VisualQuestionAnsweringPipeline(Pipeline): return model_inputs def _forward(self, model_inputs): - model_outputs = self.model(**model_inputs) + if self.model.can_generate(): + model_outputs = self.model.generate(**model_inputs) + else: + model_outputs = self.model(**model_inputs) return model_outputs def postprocess(self, model_outputs, top_k=5): - if top_k > self.model.config.num_labels: - top_k = self.model.config.num_labels - - if self.framework == "pt": - probs = model_outputs.logits.sigmoid()[0] - scores, ids = probs.topk(top_k) + if self.model.can_generate(): + return [ + {"answer": self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()} + for output_ids in model_outputs + ] else: - raise ValueError(f"Unsupported framework: {self.framework}") + if top_k > self.model.config.num_labels: + top_k = self.model.config.num_labels - scores = scores.tolist() - ids = ids.tolist() - return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)] + if self.framework == "pt": + probs = model_outputs.logits.sigmoid()[0] + scores, ids = probs.topk(top_k) + else: + raise ValueError(f"Unsupported framework: {self.framework}") + + scores = scores.tolist() + ids = ids.tolist() + return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)] diff --git a/tests/pipelines/test_pipelines_visual_question_answering.py b/tests/pipelines/test_pipelines_visual_question_answering.py index 63a5cc7097..55ad44ef8d 100644 --- a/tests/pipelines/test_pipelines_visual_question_answering.py +++ b/tests/pipelines/test_pipelines_visual_question_answering.py @@ -18,9 +18,11 @@ from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_ from transformers.pipelines import pipeline from transformers.testing_utils import ( is_pipeline_test, + is_torch_available, nested_simplify, require_tf, require_torch, + require_torch_gpu, require_vision, slow, ) @@ -28,6 +30,10 @@ from transformers.testing_utils import ( from .test_pipelines_common import ANY +if is_torch_available(): + import torch + + if is_vision_available(): from PIL import Image else: @@ -84,6 +90,37 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase): outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}] ) + @require_torch + @require_torch_gpu + def test_small_model_pt_blip2(self): + vqa_pipeline = pipeline( + "visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration" + ) + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + question = "How many cats are there?" + + outputs = vqa_pipeline(image=image, question=question) + self.assertEqual(outputs, [{"answer": ANY(str)}]) + + outputs = vqa_pipeline({"image": image, "question": question}) + self.assertEqual(outputs, [{"answer": ANY(str)}]) + + outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}]) + self.assertEqual(outputs, [[{"answer": ANY(str)}]] * 2) + + vqa_pipeline = pipeline( + "visual-question-answering", + model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration", + model_kwargs={"torch_dtype": torch.float16}, + device=0, + ) + self.assertEqual(vqa_pipeline.model.device, torch.device(0)) + self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16) + self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16) + + outputs = vqa_pipeline(image=image, question=question) + self.assertEqual(outputs, [{"answer": ANY(str)}]) + @slow @require_torch def test_large_model_pt(self): @@ -109,6 +146,31 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase): [[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2, ) + @slow + @require_torch + @require_torch_gpu + def test_large_model_pt_blip2(self): + vqa_pipeline = pipeline( + "visual-question-answering", + model="Salesforce/blip2-opt-2.7b", + model_kwargs={"torch_dtype": torch.float16}, + device=0, + ) + self.assertEqual(vqa_pipeline.model.device, torch.device(0)) + self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16) + + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + question = "Question: how many cats are there? Answer:" + + outputs = vqa_pipeline(image=image, question=question) + self.assertEqual(outputs, [{"answer": "two"}]) + + outputs = vqa_pipeline({"image": image, "question": question}) + self.assertEqual(outputs, [{"answer": "two"}]) + + outputs = vqa_pipeline([{"image": image, "question": question}, {"image": image, "question": question}]) + self.assertEqual(outputs, [[{"answer": "two"}]] * 2) + @require_tf @unittest.skip("Visual question answering not implemented in TF") def test_small_model_tf(self):