Add support for OpenAI api "image_url" input in chat for image-text-to-text pipeline (#34562)

* add support for openai api image_url input

* change continue to elif

* Explicitely add support for OpenAI/TGI chat format

* rewrite content to transformers chat format and add tests

* Add support for typing of image type in chat templates

* add base64 to possible image types

* refactor nesting
This commit is contained in:
Yoni Gozlan
2024-11-19 11:08:37 -05:00
committed by GitHub
parent 15dd625a0f
commit b99ca4d28b
2 changed files with 70 additions and 10 deletions

View File

@@ -75,16 +75,32 @@ def retrieve_images_in_messages(
retrieved_images = []
for message in messages:
for content in message["content"]:
if isinstance(content, dict) and content.get("type") == "image":
if "image" in content:
retrieved_images.append(content["image"])
elif idx_images < len(images):
if isinstance(content, dict):
if content.get("type") == "image":
for key in ["image", "url", "path", "base64"]:
if key in content:
retrieved_images.append(content[key])
break
else:
if idx_images < len(images):
retrieved_images.append(images[idx_images])
idx_images += 1
else:
raise ValueError(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
)
# Add support for OpenAI/TGI chat format
elif content.get("type") == "image_url":
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
retrieved_images.append(content["image_url"]["url"])
# Rewrite content to be in the Transformers chat format
content["type"] = "image"
content["image"] = content["image_url"]["url"]
del content["image_url"]
else:
raise ValueError(
"Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
)
# The number of images passed should be consistent with the number of images in the chat without an image key
if idx_images != len(images):

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import unittest
from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
@@ -258,3 +259,46 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
}
],
)
@slow
@require_torch
def test_model_pt_chat_template_image_url(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
},
},
{"type": "text", "text": "Describe this image in one sentence."},
],
}
]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a")
@slow
@require_torch
def test_model_pt_chat_template_image_url_base64(self):
with open("./tests/fixtures/tests_samples/COCO/000000039769.png", "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
{"type": "text", "text": "Describe this image in one sentence."},
],
}
]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "Two cats are sleeping on a pink blanket, with")