Add support for OpenAI api "image_url" input in chat for image-text-to-text pipeline (#34562)
* add support for openai api image_url input * change continue to elif * Explicitely add support for OpenAI/TGI chat format * rewrite content to transformers chat format and add tests * Add support for typing of image type in chat templates * add base64 to possible image types * refactor nesting
This commit is contained in:
@@ -75,16 +75,32 @@ def retrieve_images_in_messages(
|
||||
retrieved_images = []
|
||||
for message in messages:
|
||||
for content in message["content"]:
|
||||
if isinstance(content, dict) and content.get("type") == "image":
|
||||
if "image" in content:
|
||||
retrieved_images.append(content["image"])
|
||||
elif idx_images < len(images):
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") == "image":
|
||||
for key in ["image", "url", "path", "base64"]:
|
||||
if key in content:
|
||||
retrieved_images.append(content[key])
|
||||
break
|
||||
else:
|
||||
if idx_images < len(images):
|
||||
retrieved_images.append(images[idx_images])
|
||||
idx_images += 1
|
||||
else:
|
||||
raise ValueError(
|
||||
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
||||
)
|
||||
# Add support for OpenAI/TGI chat format
|
||||
elif content.get("type") == "image_url":
|
||||
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
|
||||
retrieved_images.append(content["image_url"]["url"])
|
||||
# Rewrite content to be in the Transformers chat format
|
||||
content["type"] = "image"
|
||||
content["image"] = content["image_url"]["url"]
|
||||
del content["image_url"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
|
||||
)
|
||||
|
||||
# The number of images passed should be consistent with the number of images in the chat without an image key
|
||||
if idx_images != len(images):
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
|
||||
@@ -258,3 +259,46 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_model_pt_chat_template_image_url(self):
|
||||
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Describe this image in one sentence."},
|
||||
],
|
||||
}
|
||||
]
|
||||
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
|
||||
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_model_pt_chat_template_image_url_base64(self):
|
||||
with open("./tests/fixtures/tests_samples/COCO/000000039769.png", "rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
|
||||
},
|
||||
{"type": "text", "text": "Describe this image in one sentence."},
|
||||
],
|
||||
}
|
||||
]
|
||||
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
|
||||
self.assertEqual(outputs, "Two cats are sleeping on a pink blanket, with")
|
||||
|
||||
Reference in New Issue
Block a user