Add support for OpenAI api "image_url" input in chat for image-text-to-text pipeline (#34562)

* add support for openai api image_url input

* change continue to elif

* Explicitely add support for OpenAI/TGI chat format

* rewrite content to transformers chat format and add tests

* Add support for typing of image type in chat templates

* add base64 to possible image types

* refactor nesting
This commit is contained in:
Yoni Gozlan
2024-11-19 11:08:37 -05:00
committed by GitHub
parent 15dd625a0f
commit b99ca4d28b
2 changed files with 70 additions and 10 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import unittest
from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
@@ -258,3 +259,46 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
}
],
)
@slow
@require_torch
def test_model_pt_chat_template_image_url(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
},
},
{"type": "text", "text": "Describe this image in one sentence."},
],
}
]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a")
@slow
@require_torch
def test_model_pt_chat_template_image_url_base64(self):
with open("./tests/fixtures/tests_samples/COCO/000000039769.png", "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode("utf-8")
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
{"type": "text", "text": "Describe this image in one sentence."},
],
}
]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "Two cats are sleeping on a pink blanket, with")