Process inputs directly in apply_chat_template in image-text-to-text pipeline (#35616)

* tokenize inputs directly in apply_chat_template

* refactor processing

* revert changes processing llava

* Update docs

* fix issue with str being iterable

* add test chat text only

* change function name
This commit is contained in:
Yoni Gozlan
2025-04-23 13:31:33 -04:00
committed by GitHub
parent 80ea2c05c2
commit 5cd6b64059
3 changed files with 186 additions and 54 deletions

View File

@@ -160,7 +160,48 @@ outputs[0]["generated_text"]
# with a yellow center in the foreground. The flower is surrounded by red and white flowers with green stems # with a yellow center in the foreground. The flower is surrounded by red and white flowers with green stems
``` ```
## Streaming If you prefer, you can also load the images separately and pass them to the pipeline like so:
```python
pipe = pipeline("image-text-to-text", model="HuggingFaceTB/SmolVLM-256M-Instruct")

img_urls = [
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
]
# Load the images ourselves instead of referencing them inside the chat.
images = [Image.open(requests.get(url, stream=True).raw) for url in img_urls]

# The bare {"type": "image"} entries are placeholders: the images passed
# through `images=` are assigned to them in order.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "What do you see in these images?"},
        ],
    }
]

outputs = pipe(text=messages, images=images, max_new_tokens=50, return_full_text=False)
outputs[0]["generated_text"]
" In the first image, there are two cats sitting on a plant. In the second image, there are flowers with a pinkish hue."
```
The images will still be included in the `"input_text"` field of the output:
```python
outputs[0]['input_text']
"""
[{'role': 'user',
'content': [{'type': 'image',
'image': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=622x412>},
{'type': 'image',
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=5184x3456>},
{'type': 'text', 'text': 'What do you see in these images?'}]}]
"""
```

## Streaming

We can use [text streaming](./generation_strategies#streaming) for a better generation experience. Transformers supports streaming with the [`TextStreamer`] or [`TextIteratorStreamer`] classes. We will use the [`TextIteratorStreamer`] with IDEFICS-8B. We can use [text streaming](./generation_strategies#streaming) for a better generation experience. Transformers supports streaming with the [`TextStreamer`] or [`TextIteratorStreamer`] classes. We will use the [`TextIteratorStreamer`] with IDEFICS-8B.

View File

@@ -58,13 +58,12 @@ class Chat:
for message in messages: for message in messages:
if not ("role" in message and "content" in message): if not ("role" in message and "content" in message):
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
images = retrieve_images_in_messages(messages, images) messages = add_images_to_messages(messages, images)
self.messages = messages self.messages = messages
self.images = images
def retrieve_images_in_messages( def add_images_to_messages(
messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]] messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
): ):
""" """
@@ -72,30 +71,27 @@ def retrieve_images_in_messages(
""" """
if images is None: if images is None:
images = [] images = []
elif not isinstance(images, Iterable): elif not isinstance(images, Iterable) or isinstance(images, str):
images = [images] images = [images]
idx_images = 0 idx_images = 0
retrieved_images = []
for message in messages: for message in messages:
for content in message["content"]: for content in message["content"]:
if isinstance(content, dict): if not isinstance(content, dict):
if content.get("type") == "image": continue
for key in ["image", "url", "path", "base64"]: content_type = content.get("type")
if key in content: if content_type == "image":
retrieved_images.append(content[key]) if not any(key in content for key in ["image", "url", "path", "base64"]):
break
else:
if idx_images < len(images): if idx_images < len(images):
retrieved_images.append(images[idx_images]) # Insert the image passed as argument in the chat message
content["image"] = images[idx_images]
idx_images += 1 idx_images += 1
else: else:
raise ValueError( raise ValueError(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline." "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
) )
# Add support for OpenAI/TGI chat format # Add support for OpenAI/TGI chat format
elif content.get("type") == "image_url": elif content_type == "image_url":
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]: if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
retrieved_images.append(content["image_url"]["url"])
# Rewrite content to be in the Transformers chat format # Rewrite content to be in the Transformers chat format
content["type"] = "image" content["type"] = "image"
content["image"] = content["image_url"]["url"] content["image"] = content["image_url"]["url"]
@@ -111,7 +107,7 @@ def retrieve_images_in_messages(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline." "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
) )
return retrieved_images return messages
@add_end_docstrings(build_pipeline_init_args(has_processor=True)) @add_end_docstrings(build_pipeline_init_args(has_processor=True))
@@ -331,32 +327,30 @@ class ImageTextToTextPipeline(Pipeline):
return super().__call__({"images": images, "text": text}, **kwargs) return super().__call__({"images": images, "text": text}, **kwargs)
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs): def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
if isinstance(inputs, Chat):
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
# because very few models support multiple separate, consecutive assistant messages
if continue_final_message is None:
continue_final_message = inputs.messages[-1]["role"] == "assistant"
model_inputs = self.processor.apply_chat_template(
inputs.messages,
add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_tensors=self.framework,
tokenize=True,
return_dict=True,
)
model_inputs["text"] = inputs
return model_inputs
# In case we only have text inputs # In case we only have text inputs
if isinstance(inputs, (list, tuple, str)): if isinstance(inputs, (list, tuple, str)):
images = None images = None
text = inputs text = inputs
inputs_text = inputs inputs_text = inputs
else: else:
if isinstance(inputs, Chat): images = load_images(inputs["images"], timeout=timeout)
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
# because very few models support multiple separate, consecutive assistant messages
if continue_final_message is None:
continue_final_message = inputs.messages[-1]["role"] == "assistant"
text = self.processor.apply_chat_template(
inputs.messages,
add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_tensors=self.framework,
**processing_kwargs,
)
inputs_text = inputs
images = inputs.images
else:
text = inputs["text"] text = inputs["text"]
inputs_text = inputs["text"] inputs_text = inputs["text"]
images = inputs["images"]
images = load_images(images, timeout=timeout)
# if batched text inputs, we set padding to True unless specified otherwise # if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1: if isinstance(text, (list, tuple)) and len(text) > 1:

View File

@@ -66,6 +66,78 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
], ],
) )
@require_torch
def test_small_model_pt_token_text_only(self):
    """Check text-only generation for a raw string prompt and for a batch of chats."""

    def user_turn(question):
        # Return a fresh dict on every call so the pipeline inputs and the
        # expected values never share (and cannot co-mutate) any object.
        return {"role": "user", "content": [{"type": "text", "text": question}]}

    pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")

    # Case 1: plain string prompt, no images.
    prompt = "What is the capital of France? Assistant:"
    outputs = pipe(text=prompt)
    expected_plain = [
        {
            "input_text": "What is the capital of France? Assistant:",
            "generated_text": "What is the capital of France? Assistant: The capital of France is Paris.",
        }
    ]
    self.assertEqual(outputs, expected_plain)

    # Case 2: a batch of two single-turn, text-only chats.
    poem_request = "Write a poem on Hugging Face, the company"
    capital_request = "What is the capital of France?"
    messages = [[user_turn(poem_request)], [user_turn(capital_request)]]
    outputs = pipe(text=messages)

    poem_reply = {
        "role": "assistant",
        "content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom",
    }
    capital_reply = {"role": "assistant", "content": "Paris"}
    # One result list per chat; each echoes the input chat and appends the reply.
    expected_chats = [
        [
            {
                "input_text": [user_turn(poem_request)],
                "generated_text": [user_turn(poem_request), poem_reply],
            }
        ],
        [
            {
                "input_text": [user_turn(capital_request)],
                "generated_text": [user_turn(capital_request), capital_reply],
            }
        ],
    ]
    self.assertEqual(outputs, expected_chats)
@require_torch @require_torch
def test_small_model_pt_token(self): def test_small_model_pt_token(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf") pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
@@ -124,7 +196,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
], ],
} }
] ]
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10) outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=True, max_new_tokens=10)
self.assertEqual( self.assertEqual(
outputs, outputs,
[ [
@@ -134,12 +206,37 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": "Whats the difference between these two images?"}, {"type": "text", "text": "Whats the difference between these two images?"},
{"type": "image"}, {
{"type": "image"}, "type": "image",
"image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
},
], ],
} }
], ],
"generated_text": "The first image shows a statue of Liberty in the", "generated_text": [
{
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between these two images?"},
{
"type": "image",
"image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
},
],
},
{
"role": "assistant",
"content": "The first image shows a statue of Liberty in the",
},
],
} }
], ],
) )