Process inputs directly in apply_chat_template in image-text-to-text pipeline (#35616)
* tokenize inputs directly in apply_chat_template
* refactor processing
* revert changes processing llava
* Update docs
* fix issue with str being iterable
* add test chat text only
* change function name
This commit is contained in:
@@ -160,7 +160,48 @@ outputs[0]["generated_text"]
|
|||||||
# with a yellow center in the foreground. The flower is surrounded by red and white flowers with green stems
|
# with a yellow center in the foreground. The flower is surrounded by red and white flowers with green stems
|
||||||
```
|
```
|
||||||
|
|
||||||
## Streaming
|
If you prefer, you can also load the images separately and pass them to the pipeline like so:
|
||||||
|
|
||||||
|
```python
|
||||||
|
pipe = pipeline("image-text-to-text", model="HuggingFaceTB/SmolVLM-256M-Instruct")
|
||||||
|
|
||||||
|
img_urls = [
|
||||||
|
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png",
|
||||||
|
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
|
||||||
|
]
|
||||||
|
images = [
|
||||||
|
Image.open(requests.get(img_urls[0], stream=True).raw),
|
||||||
|
Image.open(requests.get(img_urls[1], stream=True).raw),
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What do you see in these images?"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
outputs = pipe(text=messages, images=images, max_new_tokens=50, return_full_text=False)
|
||||||
|
outputs[0]["generated_text"]
|
||||||
|
" In the first image, there are two cats sitting on a plant. In the second image, there are flowers with a pinkish hue."
|
||||||
|
```
|
||||||
|
|
||||||
|
The images will still be included in the `"input_text"` field of the output:
|
||||||
|
|
||||||
|
```python
|
||||||
|
outputs[0]['input_text']
|
||||||
|
"""
|
||||||
|
[{'role': 'user',
|
||||||
|
'content': [{'type': 'image',
|
||||||
|
'image': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=622x412>},
|
||||||
|
{'type': 'image',
|
||||||
|
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=5184x3456>},
|
||||||
|
{'type': 'text', 'text': 'What do you see in these images?'}]}]
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
## Streaming
|
||||||
|
|
||||||
We can use [text streaming](./generation_strategies#streaming) for a better generation experience. Transformers supports streaming with the [`TextStreamer`] or [`TextIteratorStreamer`] classes. We will use the [`TextIteratorStreamer`] with IDEFICS-8B.
|
We can use [text streaming](./generation_strategies#streaming) for a better generation experience. Transformers supports streaming with the [`TextStreamer`] or [`TextIteratorStreamer`] classes. We will use the [`TextIteratorStreamer`] with IDEFICS-8B.
|
||||||
|
|
||||||
|
|||||||
@@ -58,13 +58,12 @@ class Chat:
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
if not ("role" in message and "content" in message):
|
if not ("role" in message and "content" in message):
|
||||||
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
|
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
|
||||||
images = retrieve_images_in_messages(messages, images)
|
messages = add_images_to_messages(messages, images)
|
||||||
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.images = images
|
|
||||||
|
|
||||||
|
|
||||||
def retrieve_images_in_messages(
|
def add_images_to_messages(
|
||||||
messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
|
messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -72,38 +71,35 @@ def retrieve_images_in_messages(
|
|||||||
"""
|
"""
|
||||||
if images is None:
|
if images is None:
|
||||||
images = []
|
images = []
|
||||||
elif not isinstance(images, Iterable):
|
elif not isinstance(images, Iterable) or isinstance(images, str):
|
||||||
images = [images]
|
images = [images]
|
||||||
idx_images = 0
|
idx_images = 0
|
||||||
retrieved_images = []
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
for content in message["content"]:
|
for content in message["content"]:
|
||||||
if isinstance(content, dict):
|
if not isinstance(content, dict):
|
||||||
if content.get("type") == "image":
|
continue
|
||||||
for key in ["image", "url", "path", "base64"]:
|
content_type = content.get("type")
|
||||||
if key in content:
|
if content_type == "image":
|
||||||
retrieved_images.append(content[key])
|
if not any(key in content for key in ["image", "url", "path", "base64"]):
|
||||||
break
|
if idx_images < len(images):
|
||||||
else:
|
# Insert the image passed as argument in the chat message
|
||||||
if idx_images < len(images):
|
content["image"] = images[idx_images]
|
||||||
retrieved_images.append(images[idx_images])
|
idx_images += 1
|
||||||
idx_images += 1
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
|
||||||
)
|
|
||||||
# Add support for OpenAI/TGI chat format
|
|
||||||
elif content.get("type") == "image_url":
|
|
||||||
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
|
|
||||||
retrieved_images.append(content["image_url"]["url"])
|
|
||||||
# Rewrite content to be in the Transformers chat format
|
|
||||||
content["type"] = "image"
|
|
||||||
content["image"] = content["image_url"]["url"]
|
|
||||||
del content["image_url"]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
|
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
||||||
)
|
)
|
||||||
|
# Add support for OpenAI/TGI chat format
|
||||||
|
elif content_type == "image_url":
|
||||||
|
if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
|
||||||
|
# Rewrite content to be in the Transformers chat format
|
||||||
|
content["type"] = "image"
|
||||||
|
content["image"] = content["image_url"]["url"]
|
||||||
|
del content["image_url"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
|
||||||
|
)
|
||||||
|
|
||||||
# The number of images passed should be consistent with the number of images in the chat without an image key
|
# The number of images passed should be consistent with the number of images in the chat without an image key
|
||||||
if idx_images != len(images):
|
if idx_images != len(images):
|
||||||
@@ -111,7 +107,7 @@ def retrieve_images_in_messages(
|
|||||||
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
||||||
)
|
)
|
||||||
|
|
||||||
return retrieved_images
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@add_end_docstrings(build_pipeline_init_args(has_processor=True))
|
@add_end_docstrings(build_pipeline_init_args(has_processor=True))
|
||||||
@@ -331,32 +327,30 @@ class ImageTextToTextPipeline(Pipeline):
|
|||||||
return super().__call__({"images": images, "text": text}, **kwargs)
|
return super().__call__({"images": images, "text": text}, **kwargs)
|
||||||
|
|
||||||
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
|
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
|
||||||
|
if isinstance(inputs, Chat):
|
||||||
|
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
|
||||||
|
# because very few models support multiple separate, consecutive assistant messages
|
||||||
|
if continue_final_message is None:
|
||||||
|
continue_final_message = inputs.messages[-1]["role"] == "assistant"
|
||||||
|
model_inputs = self.processor.apply_chat_template(
|
||||||
|
inputs.messages,
|
||||||
|
add_generation_prompt=not continue_final_message,
|
||||||
|
continue_final_message=continue_final_message,
|
||||||
|
return_tensors=self.framework,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
model_inputs["text"] = inputs
|
||||||
|
return model_inputs
|
||||||
# In case we only have text inputs
|
# In case we only have text inputs
|
||||||
if isinstance(inputs, (list, tuple, str)):
|
if isinstance(inputs, (list, tuple, str)):
|
||||||
images = None
|
images = None
|
||||||
text = inputs
|
text = inputs
|
||||||
inputs_text = inputs
|
inputs_text = inputs
|
||||||
else:
|
else:
|
||||||
if isinstance(inputs, Chat):
|
images = load_images(inputs["images"], timeout=timeout)
|
||||||
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
|
text = inputs["text"]
|
||||||
# because very few models support multiple separate, consecutive assistant messages
|
inputs_text = inputs["text"]
|
||||||
if continue_final_message is None:
|
|
||||||
continue_final_message = inputs.messages[-1]["role"] == "assistant"
|
|
||||||
text = self.processor.apply_chat_template(
|
|
||||||
inputs.messages,
|
|
||||||
add_generation_prompt=not continue_final_message,
|
|
||||||
continue_final_message=continue_final_message,
|
|
||||||
return_tensors=self.framework,
|
|
||||||
**processing_kwargs,
|
|
||||||
)
|
|
||||||
inputs_text = inputs
|
|
||||||
images = inputs.images
|
|
||||||
else:
|
|
||||||
text = inputs["text"]
|
|
||||||
inputs_text = inputs["text"]
|
|
||||||
images = inputs["images"]
|
|
||||||
|
|
||||||
images = load_images(images, timeout=timeout)
|
|
||||||
|
|
||||||
# if batched text inputs, we set padding to True unless specified otherwise
|
# if batched text inputs, we set padding to True unless specified otherwise
|
||||||
if isinstance(text, (list, tuple)) and len(text) > 1:
|
if isinstance(text, (list, tuple)) and len(text) > 1:
|
||||||
|
|||||||
@@ -66,6 +66,78 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_token_text_only(self):
|
||||||
|
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
|
||||||
|
text = "What is the capital of France? Assistant:"
|
||||||
|
|
||||||
|
outputs = pipe(text=text)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_text": "What is the capital of France? Assistant:",
|
||||||
|
"generated_text": "What is the capital of France? Assistant: The capital of France is Paris.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Write a poem on Hugging Face, the company"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is the capital of France?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
]
|
||||||
|
outputs = pipe(text=messages)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_text": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"generated_text": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_text": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}
|
||||||
|
],
|
||||||
|
"generated_text": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
|
||||||
|
{"role": "assistant", "content": "Paris"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt_token(self):
|
def test_small_model_pt_token(self):
|
||||||
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
|
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
|
||||||
@@ -124,7 +196,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10)
|
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=True, max_new_tokens=10)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs,
|
outputs,
|
||||||
[
|
[
|
||||||
@@ -134,12 +206,37 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
|||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "What’s the difference between these two images?"},
|
{"type": "text", "text": "What’s the difference between these two images?"},
|
||||||
{"type": "image"},
|
{
|
||||||
{"type": "image"},
|
"type": "image",
|
||||||
|
"image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"generated_text": "The first image shows a statue of Liberty in the",
|
"generated_text": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What’s the difference between these two images?"},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "The first image shows a statue of Liberty in the",
|
||||||
|
},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user