Process inputs directly in apply_chat_template in image-text-to-text pipeline (#35616)

* tokenize inputs directly in apply_chat_template

* refactor processing

* revert changes processing llava

* Update docs

* fix issue with str being iterable

* add test chat text only

* change function name
This commit is contained in:
Yoni Gozlan
2025-04-23 13:31:33 -04:00
committed by GitHub
parent 80ea2c05c2
commit 5cd6b64059
3 changed files with 186 additions and 54 deletions

View File

@@ -160,7 +160,48 @@ outputs[0]["generated_text"]
# with a yellow center in the foreground. The flower is surrounded by red and white flowers with green stems
```
## Streaming
If you prefer, you can also load the images separately and pass them to the pipeline like so:
```python
pipe = pipeline("image-text-to-text", model="HuggingFaceTB/SmolVLM-256M-Instruct")
img_urls = [
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
]
images = [
Image.open(requests.get(img_urls[0], stream=True).raw),
Image.open(requests.get(img_urls[1], stream=True).raw),
]
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image"},
{"type": "text", "text": "What do you see in these images?"},
],
}
]
outputs = pipe(text=messages, images=images, max_new_tokens=50, return_full_text=False)
outputs[0]["generated_text"]
" In the first image, there are two cats sitting on a plant. In the second image, there are flowers with a pinkish hue."
```
The images will still be included in the `"input_text"` field of the output:
```python
outputs[0]['input_text']
"""
[{'role': 'user',
'content': [{'type': 'image',
'image': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=622x412>},
{'type': 'image',
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=5184x3456>},
{'type': 'text', 'text': 'What do you see in these images?'}]}]## Streaming
"""
```
We can use [text streaming](./generation_strategies#streaming) for a better generation experience. Transformers supports streaming with the [`TextStreamer`] or [`TextIteratorStreamer`] classes. We will use the [`TextIteratorStreamer`] with IDEFICS-8B.

View File

@@ -58,13 +58,12 @@ class Chat:
for message in messages:
if not ("role" in message and "content" in message):
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
images = retrieve_images_in_messages(messages, images)
messages = add_images_to_messages(messages, images)
self.messages = messages
self.images = images
def retrieve_images_in_messages(
def add_images_to_messages(
messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
):
"""
@@ -72,30 +71,27 @@ def retrieve_images_in_messages(
"""
if images is None:
images = []
elif not isinstance(images, Iterable):
elif not isinstance(images, Iterable) or isinstance(images, str):
images = [images]
idx_images = 0
retrieved_images = []
for message in messages:
for content in message["content"]:
if isinstance(content, dict):
if content.get("type") == "image":
for key in ["image", "url", "path", "base64"]:
if key in content:
retrieved_images.append(content[key])
break
else:
if not isinstance(content, dict):
continue
content_type = content.get("type")
if content_type == "image":
if not any(key in content for key in ["image", "url", "path", "base64"]):
if idx_images < len(images):
retrieved_images.append(images[idx_images])
# Insert the image passed as argument in the chat message
content["image"] = images[idx_images]
idx_images += 1
else:
raise ValueError(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
)
# Add support for OpenAI/TGI chat format
elif content.get("type") == "image_url":
elif content_type == "image_url":
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
retrieved_images.append(content["image_url"]["url"])
# Rewrite content to be in the Transformers chat format
content["type"] = "image"
content["image"] = content["image_url"]["url"]
@@ -111,7 +107,7 @@ def retrieve_images_in_messages(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
)
return retrieved_images
return messages
@add_end_docstrings(build_pipeline_init_args(has_processor=True))
@@ -331,32 +327,30 @@ class ImageTextToTextPipeline(Pipeline):
return super().__call__({"images": images, "text": text}, **kwargs)
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
if isinstance(inputs, Chat):
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
# because very few models support multiple separate, consecutive assistant messages
if continue_final_message is None:
continue_final_message = inputs.messages[-1]["role"] == "assistant"
model_inputs = self.processor.apply_chat_template(
inputs.messages,
add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_tensors=self.framework,
tokenize=True,
return_dict=True,
)
model_inputs["text"] = inputs
return model_inputs
# In case we only have text inputs
if isinstance(inputs, (list, tuple, str)):
images = None
text = inputs
inputs_text = inputs
else:
if isinstance(inputs, Chat):
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
# because very few models support multiple separate, consecutive assistant messages
if continue_final_message is None:
continue_final_message = inputs.messages[-1]["role"] == "assistant"
text = self.processor.apply_chat_template(
inputs.messages,
add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_tensors=self.framework,
**processing_kwargs,
)
inputs_text = inputs
images = inputs.images
else:
images = load_images(inputs["images"], timeout=timeout)
text = inputs["text"]
inputs_text = inputs["text"]
images = inputs["images"]
images = load_images(images, timeout=timeout)
# if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1:

View File

@@ -66,6 +66,78 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
],
)
@require_torch
def test_small_model_pt_token_text_only(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
text = "What is the capital of France? Assistant:"
outputs = pipe(text=text)
self.assertEqual(
outputs,
[
{
"input_text": "What is the capital of France? Assistant:",
"generated_text": "What is the capital of France? Assistant: The capital of France is Paris.",
}
],
)
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Write a poem on Hugging Face, the company"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is the capital of France?"},
],
},
],
]
outputs = pipe(text=messages)
self.assertEqual(
outputs,
[
[
{
"input_text": [
{
"role": "user",
"content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
}
],
"generated_text": [
{
"role": "user",
"content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
},
{
"role": "assistant",
"content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom",
},
],
}
],
[
{
"input_text": [
{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}
],
"generated_text": [
{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
{"role": "assistant", "content": "Paris"},
],
}
],
],
)
@require_torch
def test_small_model_pt_token(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
@@ -124,7 +196,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
],
}
]
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10)
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=True, max_new_tokens=10)
self.assertEqual(
outputs,
[
@@ -134,12 +206,37 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between these two images?"},
{"type": "image"},
{"type": "image"},
{
"type": "image",
"image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
},
],
}
],
"generated_text": "The first image shows a statue of Liberty in the",
"generated_text": [
{
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between these two images?"},
{
"type": "image",
"image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
},
{
"type": "image",
"image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
},
],
},
{
"role": "assistant",
"content": "The first image shows a statue of Liberty in the",
},
],
}
],
)