Add support for post-processing kwargs in image-text-to-text pipeline (#35374)
* fix error and improve pipeline * add processing_kwargs to apply_chat_template * change default post_process kwarg to args * Fix slow tests * fix copies
This commit is contained in:
@@ -682,7 +682,7 @@ class FuyuProcessor(ProcessorMixin):
|
||||
|
||||
return results
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
|
||||
"""
|
||||
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
|
||||
|
||||
@@ -690,6 +690,10 @@ class FuyuProcessor(ProcessorMixin):
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
containing the token ids of the generated sequences.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text output.
|
||||
@@ -706,7 +710,7 @@ class FuyuProcessor(ProcessorMixin):
|
||||
for i, seq in enumerate(unpadded_output_sequences):
|
||||
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)
|
||||
|
||||
return self.batch_decode(padded_output_sequences, skip_special_tokens=True)
|
||||
return self.batch_decode(padded_output_sequences, skip_special_tokens=skip_special_tokens, **kwargs)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -428,7 +428,7 @@ class Kosmos2Processor(ProcessorMixin):
|
||||
return clean_text_and_extract_entities_with_bboxes(caption)
|
||||
return caption
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
@@ -436,11 +436,15 @@ class Kosmos2Processor(ProcessorMixin):
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
|
||||
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
|
||||
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]
|
||||
|
||||
@property
|
||||
|
||||
@@ -346,7 +346,9 @@ class MllamaProcessor(ProcessorMixin):
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
def post_process_image_text_to_text(
|
||||
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
@@ -354,12 +356,21 @@ class MllamaProcessor(ProcessorMixin):
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
generated_outputs,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -192,7 +192,9 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
def post_process_image_text_to_text(
|
||||
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
@@ -200,12 +202,21 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
generated_outputs,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -170,7 +170,9 @@ class Qwen2VLProcessor(ProcessorMixin):
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
def post_process_image_text_to_text(
|
||||
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||
):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
@@ -178,12 +180,21 @@ class Qwen2VLProcessor(ProcessorMixin):
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
generated_outputs,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
from collections.abc import Iterable # pylint: disable=g-importing-member
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ..processing_utils import ProcessingKwargs, Unpack
|
||||
@@ -71,6 +72,8 @@ def retrieve_images_in_messages(
|
||||
"""
|
||||
if images is None:
|
||||
images = []
|
||||
elif not isinstance(images, Iterable):
|
||||
images = [images]
|
||||
idx_images = 0
|
||||
retrieved_images = []
|
||||
for message in messages:
|
||||
@@ -188,14 +191,15 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
return_full_text=None,
|
||||
return_tensors=None,
|
||||
return_type=None,
|
||||
clean_up_tokenization_spaces=None,
|
||||
stop_sequence=None,
|
||||
continue_final_message=None,
|
||||
**kwargs: Unpack[ProcessingKwargs],
|
||||
):
|
||||
forward_kwargs = {}
|
||||
preprocess_params = {}
|
||||
postprocess_params = {}
|
||||
|
||||
preprocess_params["processing_kwargs"] = kwargs
|
||||
preprocess_params.update(kwargs)
|
||||
|
||||
if timeout is not None:
|
||||
preprocess_params["timeout"] = timeout
|
||||
@@ -226,7 +230,16 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
postprocess_params["return_type"] = return_type
|
||||
if continue_final_message is not None:
|
||||
postprocess_params["continue_final_message"] = continue_final_message
|
||||
|
||||
if clean_up_tokenization_spaces is not None:
|
||||
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
|
||||
if stop_sequence is not None:
|
||||
stop_sequence_ids = self.processor.tokenizer.encode(stop_sequence, add_special_tokens=False)
|
||||
if len(stop_sequence_ids) > 1:
|
||||
logger.warning_once(
|
||||
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
|
||||
" the stop sequence will be used as the stop sequence string in the interim."
|
||||
)
|
||||
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
|
||||
return preprocess_params, forward_kwargs, postprocess_params
|
||||
|
||||
def __call__(
|
||||
@@ -264,6 +277,8 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
return_full_text (`bool`, *optional*, defaults to `True`):
|
||||
If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
|
||||
specified at the same time as `return_text`.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the
|
||||
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
|
||||
By default this is `True` when the final message in the input chat has the `assistant` role and
|
||||
@@ -315,7 +330,7 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
|
||||
return super().__call__({"images": images, "text": text}, **kwargs)
|
||||
|
||||
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
|
||||
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
|
||||
# In case we only have text inputs
|
||||
if isinstance(inputs, (list, tuple, str)):
|
||||
images = None
|
||||
@@ -332,6 +347,7 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
add_generation_prompt=not continue_final_message,
|
||||
continue_final_message=continue_final_message,
|
||||
return_tensors=self.framework,
|
||||
**processing_kwargs,
|
||||
)
|
||||
inputs_text = inputs
|
||||
images = inputs.images
|
||||
@@ -340,7 +356,7 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
inputs_text = inputs["text"]
|
||||
images = inputs["images"]
|
||||
|
||||
images = load_images(images)
|
||||
images = load_images(images, timeout=timeout)
|
||||
|
||||
# if batched text inputs, we set padding to True unless specified otherwise
|
||||
if isinstance(text, (list, tuple)) and len(text) > 1:
|
||||
@@ -363,7 +379,9 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
|
||||
return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}
|
||||
|
||||
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
|
||||
def postprocess(
|
||||
self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None, **postprocess_kwargs
|
||||
):
|
||||
input_texts = model_outputs["prompt_text"]
|
||||
input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
|
||||
generated_sequence = model_outputs["generated_sequence"]
|
||||
@@ -375,8 +393,8 @@ class ImageTextToTextPipeline(Pipeline):
|
||||
]
|
||||
|
||||
# Decode inputs and outputs the same way to remove input text from generated text if present
|
||||
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
|
||||
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)
|
||||
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence, **postprocess_kwargs)
|
||||
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids, **postprocess_kwargs)
|
||||
|
||||
# Force consistent behavior for including the input text in the output
|
||||
if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
|
||||
|
||||
@@ -1392,7 +1392,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
return out["input_ids"]
|
||||
return prompt
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
|
||||
"""
|
||||
Post-process the output of a vlm to decode the text.
|
||||
|
||||
@@ -1400,11 +1400,15 @@ class ProcessorMixin(PushToHubMixin):
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||
**kwargs:
|
||||
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
|
||||
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
|
||||
|
||||
|
||||
def _validate_images_text_input_order(images, text):
|
||||
|
||||
@@ -124,7 +124,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
],
|
||||
}
|
||||
]
|
||||
outputs = pipe([image_ny, image_chicago], text=messages)
|
||||
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
@@ -139,20 +139,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
],
|
||||
}
|
||||
],
|
||||
"generated_text": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What’s the difference between these two images?"},
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows",
|
||||
},
|
||||
],
|
||||
"generated_text": "The first image shows a statue of Liberty in the",
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -179,7 +166,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
],
|
||||
},
|
||||
]
|
||||
outputs = pipe(text=messages)
|
||||
outputs = pipe(text=messages, max_new_tokens=10)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
@@ -213,7 +200,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "There is a dog and a person in the image. The dog is sitting on the sand, and the person is sitting on",
|
||||
"text": "There is a dog and a person in the image. The dog is sitting",
|
||||
}
|
||||
],
|
||||
},
|
||||
@@ -238,7 +225,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
],
|
||||
}
|
||||
]
|
||||
outputs = pipe(text=messages, return_full_text=False)
|
||||
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
@@ -255,7 +242,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
],
|
||||
}
|
||||
],
|
||||
"generated_text": "In the image, a woman is sitting on the sandy beach, her legs crossed in a relaxed manner",
|
||||
"generated_text": "In the image, a woman is sitting on the",
|
||||
}
|
||||
],
|
||||
)
|
||||
@@ -263,7 +250,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
@slow
|
||||
@require_torch
|
||||
def test_model_pt_chat_template_image_url(self):
|
||||
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -279,7 +266,7 @@ class ImageTextToTextPipelineTests(unittest.TestCase):
|
||||
}
|
||||
]
|
||||
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
|
||||
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a")
|
||||
self.assertEqual(outputs, "A statue of liberty in the foreground of a city")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user