[image-to-text pipeline] Add conditional text support + GIT (#23362)
* First draft * Remove print statements * Add conditional generation * Add more tests * Remove scripts * Remove BLIP specific lines * Add support for pix2struct * Add fast test * Address comment * Fix style
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
|
||||
from transformers.pipelines import pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_vision, slow
|
||||
@@ -125,6 +127,15 @@ class ImageToTextPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
def test_small_model_pt_conditional(self):
    """Prompt-conditioned generation: the generated text must begin with the prompt.

    Runs the "image-to-text" pipeline on a local COCO fixture image with a text
    prompt and checks that the first result's ``generated_text`` starts with it.
    """
    prefix = "a photo of"
    image_path = "./tests/fixtures/tests_samples/COCO/000000039769.png"
    captioner = pipeline("image-to-text", model="hf-internal-testing/tiny-random-BlipForConditionalGeneration")

    first_caption = captioner(image_path, prompt=prefix)[0]["generated_text"]

    # Only the prefix is checked here, not the full generated text.
    self.assertTrue(first_caption.startswith(prefix))
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_large_model_pt(self):
|
||||
@@ -143,6 +154,71 @@ class ImageToTextPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
@require_torch
def test_generation_pt_blip(self):
    """Unconditional captioning with the BLIP base captioning checkpoint.

    Downloads a sample image, runs the pipeline without a prompt and compares
    the full output list against the expected caption.
    """
    captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

    image_url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
    response = requests.get(image_url, stream=True)
    # The image is opened straight from the streamed response body.
    picture = Image.open(response.raw)

    expected = [{"generated_text": "a pink pokemon pokemon with a blue shirt and a blue shirt"}]
    self.assertEqual(captioner(picture), expected)
|
||||
|
||||
@slow
@require_torch
def test_generation_pt_git(self):
    """Unconditional captioning with the GIT base COCO checkpoint.

    Downloads a sample image, runs the pipeline without a prompt and compares
    the full output list against the expected caption.
    """
    captioner = pipeline("image-to-text", model="microsoft/git-base-coco")

    image_url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
    response = requests.get(image_url, stream=True)
    # The image is opened straight from the streamed response body.
    picture = Image.open(response.raw)

    expected = [{"generated_text": "a cartoon of a purple character."}]
    self.assertEqual(captioner(picture), expected)
|
||||
|
||||
@slow
@require_torch
def test_conditional_generation_pt_blip(self):
    """Prompt-conditioned captioning with the BLIP base captioning checkpoint.

    Checks two things:
      * a single image with a single prompt yields the expected caption;
      * a list of images combined with a list of prompts raises ``ValueError``.
    """
    pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    prompt = "a photography of"

    outputs = pipe(image, prompt=prompt)
    self.assertEqual(outputs, [{"generated_text": "a photography of a volcano"}])

    # The call is expected to raise, so its result is deliberately not bound.
    with self.assertRaises(ValueError):
        pipe([image, image], prompt=[prompt, prompt])
|
||||
|
||||
@slow
@require_torch
def test_conditional_generation_pt_git(self):
    """Prompt-conditioned captioning with the GIT base COCO checkpoint.

    Checks two things:
      * a single image with a single prompt yields the expected caption;
      * a list of images combined with a list of prompts raises ``ValueError``.
    """
    pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    prompt = "a photo of a"

    outputs = pipe(image, prompt=prompt)
    self.assertEqual(outputs, [{"generated_text": "a photo of a tent with a tent and a tent in the background."}])

    # The call is expected to raise, so its result is deliberately not bound.
    with self.assertRaises(ValueError):
        pipe([image, image], prompt=[prompt, prompt])
|
||||
|
||||
@slow
@require_torch
def test_conditional_generation_pt_pix2struct(self):
    """Prompt-conditioned generation with the Pix2Struct AI2D base checkpoint.

    Checks two things:
      * a single image with a multiple-choice question as the prompt yields the
        expected answer text;
      * a list of images combined with a list of prompts raises ``ValueError``.
    """
    pipe = pipeline("image-to-text", model="google/pix2struct-ai2d-base")
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    prompt = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"

    outputs = pipe(image, prompt=prompt)
    self.assertEqual(outputs, [{"generated_text": "ash cloud"}])

    # The call is expected to raise, so its result is deliberately not bound.
    with self.assertRaises(ValueError):
        pipe([image, image], prompt=[prompt, prompt])
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_large_model_tf(self):
|
||||
|
||||
Reference in New Issue
Block a user