From b99ca4d28b47fa7166e7882cb0695a5c0cc0d411 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:08:37 -0500 Subject: [PATCH] Add support for OpenAI api "image_url" input in chat for image-text-to-text pipeline (#34562) * add support for openai api image_url input * change continue to elif * Explicitly add support for OpenAI/TGI chat format * rewrite content to transformers chat format and add tests * Add support for typing of image type in chat templates * add base64 to possible image types * refactor nesting --- .../pipelines/image_text_to_text.py | 36 ++++++++++----- .../test_pipelines_image_text_to_text.py | 44 +++++++++++++++++++ 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 39738ffc38..5afba0d7c0 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -75,16 +75,32 @@ def retrieve_images_in_messages( retrieved_images = [] for message in messages: for content in message["content"]: - if isinstance(content, dict) and content.get("type") == "image": - if "image" in content: - retrieved_images.append(content["image"]) - elif idx_images < len(images): - retrieved_images.append(images[idx_images]) - idx_images += 1 - else: - raise ValueError( - "The number of images in the chat messages should be the same as the number of images passed to the pipeline." - ) + if isinstance(content, dict): + if content.get("type") == "image": + for key in ["image", "url", "path", "base64"]: + if key in content: + retrieved_images.append(content[key]) + break + else: + if idx_images < len(images): + retrieved_images.append(images[idx_images]) + idx_images += 1 + else: + raise ValueError( + "The number of images in the chat messages should be the same as the number of images passed to the pipeline." 
+ ) + # Add support for OpenAI/TGI chat format + elif content.get("type") == "image_url": + if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]: + retrieved_images.append(content["image_url"]["url"]) + # Rewrite content to be in the Transformers chat format + content["type"] = "image" + content["image"] = content["image_url"]["url"] + del content["image_url"] + else: + raise ValueError( + "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key." + ) # The number of images passed should be consistent with the number of images in the chat without an image key if idx_images != len(images): diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py index b44b9decf9..7b9e17edd3 100644 --- a/tests/pipelines/test_pipelines_image_text_to_text.py +++ b/tests/pipelines/test_pipelines_image_text_to_text.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import base64 import unittest from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available @@ -258,3 +259,46 @@ class ImageTextToTextPipelineTests(unittest.TestCase): } ], ) + + @slow + @require_torch + def test_model_pt_chat_template_image_url(self): + pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + }, + }, + {"type": "text", "text": "Describe this image in one sentence."}, + ], + } + ] + outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"] + self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a") + + @slow + @require_torch + def test_model_pt_chat_template_image_url_base64(self): + with open("./tests/fixtures/tests_samples/COCO/000000039769.png", "rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode("utf-8") + + pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, + }, + {"type": "text", "text": "Describe this image in one sentence."}, + ], + } + ] + outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"] + self.assertEqual(outputs, "Two cats are sleeping on a pink blanket, with")