[image-to-text pipeline] Add conditional text support + GIT (#23362)

* First draft

* Remove print statements

* Add conditional generation

* Add more tests

* Remove scripts

* Remove BLIP specific linkes

* Add support for pix2struct

* Add fast test

* Address comment

* Fix style
This commit is contained in:
NielsRogge
2023-05-22 21:45:50 +02:00
committed by GitHub
parent e69feab8a1
commit 2f424d7979
3 changed files with 123 additions and 4 deletions

View File

@@ -14,6 +14,8 @@
import unittest
import requests
from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
from transformers.pipelines import pipeline
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_vision, slow
@@ -125,6 +127,15 @@ class ImageToTextPipelineTests(unittest.TestCase):
],
)
@require_torch
def test_small_model_pt_conditional(self):
pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-BlipForConditionalGeneration")
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
prompt = "a photo of"
outputs = pipe(image, prompt=prompt)
self.assertTrue(outputs[0]["generated_text"].startswith(prompt))
@slow
@require_torch
def test_large_model_pt(self):
@@ -143,6 +154,71 @@ class ImageToTextPipelineTests(unittest.TestCase):
],
)
@slow
@require_torch
def test_generation_pt_blip(self):
pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
outputs = pipe(image)
self.assertEqual(outputs, [{"generated_text": "a pink pokemon pokemon with a blue shirt and a blue shirt"}])
@slow
@require_torch
def test_generation_pt_git(self):
pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
outputs = pipe(image)
self.assertEqual(outputs, [{"generated_text": "a cartoon of a purple character."}])
@slow
@require_torch
def test_conditional_generation_pt_blip(self):
pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "a photography of"
outputs = pipe(image, prompt=prompt)
self.assertEqual(outputs, [{"generated_text": "a photography of a volcano"}])
with self.assertRaises(ValueError):
outputs = pipe([image, image], prompt=[prompt, prompt])
@slow
@require_torch
def test_conditional_generation_pt_git(self):
pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "a photo of a"
outputs = pipe(image, prompt=prompt)
self.assertEqual(outputs, [{"generated_text": "a photo of a tent with a tent and a tent in the background."}])
with self.assertRaises(ValueError):
outputs = pipe([image, image], prompt=[prompt, prompt])
@slow
@require_torch
def test_conditional_generation_pt_pix2struct(self):
pipe = pipeline("image-to-text", model="google/pix2struct-ai2d-base")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
outputs = pipe(image, prompt=prompt)
self.assertEqual(outputs, [{"generated_text": "ash cloud"}])
with self.assertRaises(ValueError):
outputs = pipe([image, image], prompt=[prompt, prompt])
@slow
@require_tf
def test_large_model_tf(self):