Add support for OpenAI api "image_url" input in chat for image-text-to-text pipeline (#34562)
* add support for openai api image_url input
* change continue to elif
* Explicitly add support for OpenAI/TGI chat format
* rewrite content to transformers chat format and add tests
* Add support for typing of image type in chat templates
* add base64 to possible image types
* refactor nesting
This commit is contained in:
@@ -75,16 +75,32 @@ def retrieve_images_in_messages(
|
|||||||
retrieved_images = []
|
retrieved_images = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
for content in message["content"]:
|
for content in message["content"]:
|
||||||
if isinstance(content, dict) and content.get("type") == "image":
|
if isinstance(content, dict):
|
||||||
if "image" in content:
|
if content.get("type") == "image":
|
||||||
retrieved_images.append(content["image"])
|
for key in ["image", "url", "path", "base64"]:
|
||||||
elif idx_images < len(images):
|
if key in content:
|
||||||
|
retrieved_images.append(content[key])
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if idx_images < len(images):
|
||||||
retrieved_images.append(images[idx_images])
|
retrieved_images.append(images[idx_images])
|
||||||
idx_images += 1
|
idx_images += 1
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
||||||
)
|
)
|
||||||
|
# Add support for OpenAI/TGI chat format
|
||||||
|
elif content.get("type") == "image_url":
|
||||||
|
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
|
||||||
|
retrieved_images.append(content["image_url"]["url"])
|
||||||
|
# Rewrite content to be in the Transformers chat format
|
||||||
|
content["type"] = "image"
|
||||||
|
content["image"] = content["image_url"]["url"]
|
||||||
|
del content["image_url"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
|
||||||
|
)
|
||||||
|
|
||||||
# The number of images passed should be consistent with the number of images in the chat without an image key
|
# The number of images passed should be consistent with the number of images in the chat without an image key
|
||||||
if idx_images != len(images):
|
if idx_images != len(images):
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
|
from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
|
||||||
@@ -258,3 +259,46 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_model_pt_chat_template_image_url(self):
|
||||||
|
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Describe this image in one sentence."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
|
||||||
|
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_model_pt_chat_template_image_url_base64(self):
|
||||||
|
with open("./tests/fixtures/tests_samples/COCO/000000039769.png", "rb") as image_file:
|
||||||
|
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|
||||||
|
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Describe this image in one sentence."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
|
||||||
|
self.assertEqual(outputs, "Two cats are sleeping on a pink blanket, with")
|
||||||
|
|||||||
Reference in New Issue
Block a user