Add support for post-processing kwargs in image-text-to-text pipeline (#35374)

* fix error and improve pipeline

* add processing_kwargs to apply_chat_template

* change default post_process kwarg to args

* Fix slow tests

* fix copies
This commit is contained in:
Yoni Gozlan
2025-02-18 17:43:36 -05:00
committed by GitHub
parent 9b479a245b
commit 9f51dc2535
8 changed files with 91 additions and 41 deletions

View File

@@ -682,7 +682,7 @@ class FuyuProcessor(ProcessorMixin):
return results return results
def post_process_image_text_to_text(self, generated_outputs): def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
""" """
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output. Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
@@ -690,6 +690,10 @@ class FuyuProcessor(ProcessorMixin):
generated_outputs (`torch.Tensor` or `np.ndarray`): generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)` The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
containing the token ids of the generated sequences. containing the token ids of the generated sequences.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns: Returns:
`List[str]`: The decoded text output. `List[str]`: The decoded text output.
@@ -706,7 +710,7 @@ class FuyuProcessor(ProcessorMixin):
for i, seq in enumerate(unpadded_output_sequences): for i, seq in enumerate(unpadded_output_sequences):
padded_output_sequences[i, : len(seq)] = torch.tensor(seq) padded_output_sequences[i, : len(seq)] = torch.tensor(seq)
return self.batch_decode(padded_output_sequences, skip_special_tokens=True) return self.batch_decode(padded_output_sequences, skip_special_tokens=skip_special_tokens, **kwargs)
def batch_decode(self, *args, **kwargs): def batch_decode(self, *args, **kwargs):
""" """

View File

@@ -428,7 +428,7 @@ class Kosmos2Processor(ProcessorMixin):
return clean_text_and_extract_entities_with_bboxes(caption) return clean_text_and_extract_entities_with_bboxes(caption)
return caption return caption
def post_process_image_text_to_text(self, generated_outputs): def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
""" """
Post-process the output of the model to decode the text. Post-process the output of the model to decode the text.
@@ -436,11 +436,15 @@ class Kosmos2Processor(ProcessorMixin):
generated_outputs (`torch.Tensor` or `np.ndarray`): generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`. or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns: Returns:
`List[str]`: The decoded text. `List[str]`: The decoded text.
""" """
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True) generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts] return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]
@property @property

View File

@@ -346,7 +346,9 @@ class MllamaProcessor(ProcessorMixin):
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs): def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
""" """
Post-process the output of the model to decode the text. Post-process the output of the model to decode the text.
@@ -354,12 +356,21 @@ class MllamaProcessor(ProcessorMixin):
generated_outputs (`torch.Tensor` or `np.ndarray`): generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`. or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns: Returns:
`List[str]`: The decoded text. `List[str]`: The decoded text.
""" """
return self.tokenizer.batch_decode( return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
) )
@property @property

View File

@@ -192,7 +192,9 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs): def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
""" """
Post-process the output of the model to decode the text. Post-process the output of the model to decode the text.
@@ -200,12 +202,21 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
generated_outputs (`torch.Tensor` or `np.ndarray`): generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`. or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns: Returns:
`List[str]`: The decoded text. `List[str]`: The decoded text.
""" """
return self.tokenizer.batch_decode( return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
) )
@property @property

View File

@@ -170,7 +170,9 @@ class Qwen2VLProcessor(ProcessorMixin):
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs): def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
""" """
Post-process the output of the model to decode the text. Post-process the output of the model to decode the text.
@@ -178,12 +180,21 @@ class Qwen2VLProcessor(ProcessorMixin):
generated_outputs (`torch.Tensor` or `np.ndarray`): generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`. or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns: Returns:
`List[str]`: The decoded text. `List[str]`: The decoded text.
""" """
return self.tokenizer.batch_decode( return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
) )
@property @property

View File

@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import enum import enum
from collections.abc import Iterable # pylint: disable=g-importing-member
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from ..processing_utils import ProcessingKwargs, Unpack from ..processing_utils import ProcessingKwargs, Unpack
@@ -71,6 +72,8 @@ def retrieve_images_in_messages(
""" """
if images is None: if images is None:
images = [] images = []
elif not isinstance(images, Iterable):
images = [images]
idx_images = 0 idx_images = 0
retrieved_images = [] retrieved_images = []
for message in messages: for message in messages:
@@ -188,14 +191,15 @@ class ImageTextToTextPipeline(Pipeline):
return_full_text=None, return_full_text=None,
return_tensors=None, return_tensors=None,
return_type=None, return_type=None,
clean_up_tokenization_spaces=None,
stop_sequence=None,
continue_final_message=None, continue_final_message=None,
**kwargs: Unpack[ProcessingKwargs], **kwargs: Unpack[ProcessingKwargs],
): ):
forward_kwargs = {} forward_kwargs = {}
preprocess_params = {} preprocess_params = {}
postprocess_params = {} postprocess_params = {}
preprocess_params.update(kwargs)
preprocess_params["processing_kwargs"] = kwargs
if timeout is not None: if timeout is not None:
preprocess_params["timeout"] = timeout preprocess_params["timeout"] = timeout
@@ -226,7 +230,16 @@ class ImageTextToTextPipeline(Pipeline):
postprocess_params["return_type"] = return_type postprocess_params["return_type"] = return_type
if continue_final_message is not None: if continue_final_message is not None:
postprocess_params["continue_final_message"] = continue_final_message postprocess_params["continue_final_message"] = continue_final_message
if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
if stop_sequence is not None:
stop_sequence_ids = self.processor.tokenizer.encode(stop_sequence, add_special_tokens=False)
if len(stop_sequence_ids) > 1:
logger.warning_once(
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
" the stop sequence will be used as the stop sequence string in the interim."
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
return preprocess_params, forward_kwargs, postprocess_params return preprocess_params, forward_kwargs, postprocess_params
def __call__( def __call__(
@@ -264,6 +277,8 @@ class ImageTextToTextPipeline(Pipeline):
return_full_text (`bool`, *optional*, defaults to `True`): return_full_text (`bool`, *optional*, defaults to `True`):
If set to `False` only added text is returned, otherwise the full text is returned. Cannot be If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
specified at the same time as `return_text`. specified at the same time as `return_text`.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
Whether or not to clean up the potential extra spaces in the text output.
continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the
last message in the input chat rather than starting a new one, allowing you to "prefill" its response. last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
By default this is `True` when the final message in the input chat has the `assistant` role and By default this is `True` when the final message in the input chat has the `assistant` role and
@@ -315,7 +330,7 @@ class ImageTextToTextPipeline(Pipeline):
return super().__call__({"images": images, "text": text}, **kwargs) return super().__call__({"images": images, "text": text}, **kwargs)
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None): def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
# In case we only have text inputs # In case we only have text inputs
if isinstance(inputs, (list, tuple, str)): if isinstance(inputs, (list, tuple, str)):
images = None images = None
@@ -332,6 +347,7 @@ class ImageTextToTextPipeline(Pipeline):
add_generation_prompt=not continue_final_message, add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message, continue_final_message=continue_final_message,
return_tensors=self.framework, return_tensors=self.framework,
**processing_kwargs,
) )
inputs_text = inputs inputs_text = inputs
images = inputs.images images = inputs.images
@@ -340,7 +356,7 @@ class ImageTextToTextPipeline(Pipeline):
inputs_text = inputs["text"] inputs_text = inputs["text"]
images = inputs["images"] images = inputs["images"]
images = load_images(images) images = load_images(images, timeout=timeout)
# if batched text inputs, we set padding to True unless specified otherwise # if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1: if isinstance(text, (list, tuple)) and len(text) > 1:
@@ -363,7 +379,9 @@ class ImageTextToTextPipeline(Pipeline):
return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids} return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None): def postprocess(
self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None, **postprocess_kwargs
):
input_texts = model_outputs["prompt_text"] input_texts = model_outputs["prompt_text"]
input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
generated_sequence = model_outputs["generated_sequence"] generated_sequence = model_outputs["generated_sequence"]
@@ -375,8 +393,8 @@ class ImageTextToTextPipeline(Pipeline):
] ]
# Decode inputs and outputs the same way to remove input text from generated text if present # Decode inputs and outputs the same way to remove input text from generated text if present
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence) generated_texts = self.processor.post_process_image_text_to_text(generated_sequence, **postprocess_kwargs)
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids) decoded_inputs = self.processor.post_process_image_text_to_text(input_ids, **postprocess_kwargs)
# Force consistent behavior for including the input text in the output # Force consistent behavior for including the input text in the output
if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}: if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:

View File

@@ -1392,7 +1392,7 @@ class ProcessorMixin(PushToHubMixin):
return out["input_ids"] return out["input_ids"]
return prompt return prompt
def post_process_image_text_to_text(self, generated_outputs): def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
""" """
Post-process the output of a vlm to decode the text. Post-process the output of a vlm to decode the text.
@@ -1400,11 +1400,15 @@ class ProcessorMixin(PushToHubMixin):
generated_outputs (`torch.Tensor` or `np.ndarray`): generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`. or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns: Returns:
`List[str]`: The decoded text. `List[str]`: The decoded text.
""" """
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
def _validate_images_text_input_order(images, text): def _validate_images_text_input_order(images, text):

View File

@@ -124,7 +124,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
], ],
} }
] ]
outputs = pipe([image_ny, image_chicago], text=messages) outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10)
self.assertEqual( self.assertEqual(
outputs, outputs,
[ [
@@ -139,20 +139,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
], ],
} }
], ],
"generated_text": [ "generated_text": "The first image shows a statue of Liberty in the",
{
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between these two images?"},
{"type": "image"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows",
},
],
} }
], ],
) )
@@ -179,7 +166,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
], ],
}, },
] ]
outputs = pipe(text=messages) outputs = pipe(text=messages, max_new_tokens=10)
self.assertEqual( self.assertEqual(
outputs, outputs,
[ [
@@ -213,7 +200,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": "There is a dog and a person in the image. The dog is sitting on the sand, and the person is sitting on", "text": "There is a dog and a person in the image. The dog is sitting",
} }
], ],
}, },
@@ -238,7 +225,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
], ],
} }
] ]
outputs = pipe(text=messages, return_full_text=False) outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)
self.assertEqual( self.assertEqual(
outputs, outputs,
[ [
@@ -255,7 +242,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
], ],
} }
], ],
"generated_text": "In the image, a woman is sitting on the sandy beach, her legs crossed in a relaxed manner", "generated_text": "In the image, a woman is sitting on the",
} }
], ],
) )
@@ -263,7 +250,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
@slow @slow
@require_torch @require_torch
def test_model_pt_chat_template_image_url(self): def test_model_pt_chat_template_image_url(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf") pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
messages = [ messages = [
{ {
"role": "user", "role": "user",
@@ -279,7 +266,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
} }
] ]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"] outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a") self.assertEqual(outputs, "A statue of liberty in the foreground of a city")
@slow @slow
@require_torch @require_torch