Add support for OpenAI api "image_url" input in chat for image-text-to-text pipeline (#34562)

* add support for openai api image_url input

* change continue to elif

* Explicitely add support for OpenAI/TGI chat format

* rewrite content to transformers chat format and add tests

* Add support for typing of image type in chat templates

* add base64 to possible image types

* refactor nesting
This commit is contained in:
Yoni Gozlan
2024-11-19 11:08:37 -05:00
committed by GitHub
parent 15dd625a0f
commit b99ca4d28b
2 changed files with 70 additions and 10 deletions

View File

@@ -75,16 +75,32 @@ def retrieve_images_in_messages(
retrieved_images = [] retrieved_images = []
for message in messages: for message in messages:
for content in message["content"]: for content in message["content"]:
if isinstance(content, dict) and content.get("type") == "image": if isinstance(content, dict):
if "image" in content: if content.get("type") == "image":
retrieved_images.append(content["image"]) for key in ["image", "url", "path", "base64"]:
elif idx_images < len(images): if key in content:
retrieved_images.append(content[key])
break
else:
if idx_images < len(images):
retrieved_images.append(images[idx_images]) retrieved_images.append(images[idx_images])
idx_images += 1 idx_images += 1
else: else:
raise ValueError( raise ValueError(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline." "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
) )
# Add support for OpenAI/TGI chat format
elif content.get("type") == "image_url":
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
retrieved_images.append(content["image_url"]["url"])
# Rewrite content to be in the Transformers chat format
content["type"] = "image"
content["image"] = content["image_url"]["url"]
del content["image_url"]
else:
raise ValueError(
"Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
)
# The number of images passed should be consistent with the number of images in the chat without an image key # The number of images passed should be consistent with the number of images in the chat without an image key
if idx_images != len(images): if idx_images != len(images):

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64
import unittest import unittest
from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
@@ -258,3 +259,46 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
} }
], ],
) )
@slow
@require_torch
def test_model_pt_chat_template_image_url(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
},
},
{"type": "text", "text": "Describe this image in one sentence."},
],
}
]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a")
@slow
@require_torch
def test_model_pt_chat_template_image_url_base64(self):
with open("./tests/fixtures/tests_samples/COCO/000000039769.png", "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
{"type": "text", "text": "Describe this image in one sentence."},
],
}
]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "Two cats are sleeping on a pink blanket, with")