[image-to-text pipeline] Add conditional text support + GIT (#23362)

* First draft

* Remove print statements

* Add conditional generation

* Add more tests

* Remove scripts

* Remove BLIP specific linkes

* Add support for pix2struct

* Add fast test

* Address comment

* Fix style
This commit is contained in:
NielsRogge
2023-05-22 21:45:50 +02:00
committed by GitHub
parent e69feab8a1
commit 2f424d7979
3 changed files with 123 additions and 4 deletions

View File

@@ -20,6 +20,8 @@ if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING
logger = logging.get_logger(__name__)
@@ -56,8 +58,13 @@ class ImageToTextPipeline(Pipeline):
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
)
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
forward_kwargs = {}
preprocess_params = {}
if prompt is not None:
preprocess_params["prompt"] = prompt
if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
if max_new_tokens is not None:
@@ -69,7 +76,7 @@ class ImageToTextPipeline(Pipeline):
" please use only one"
)
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
return {}, forward_kwargs, {}
return preprocess_params, forward_kwargs, {}
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
@@ -98,9 +105,43 @@ class ImageToTextPipeline(Pipeline):
"""
return super().__call__(images, **kwargs)
def preprocess(self, image):
def preprocess(self, image, prompt=None):
image = load_image(image)
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
if prompt is not None:
if not isinstance(prompt, str):
raise ValueError(
f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
"Note also that one single text can be provided for conditional image to text generation."
)
model_type = self.model.config.model_type
if model_type == "git":
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
input_ids = [self.tokenizer.cls_token_id] + input_ids
input_ids = torch.tensor(input_ids).unsqueeze(0)
model_inputs.update({"input_ids": input_ids})
elif model_type == "pix2struct":
model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
elif model_type != "vision-encoder-decoder":
# vision-encoder-decoder does not support conditional generation
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
model_inputs.update(text_inputs)
else:
raise ValueError(f"Model type {model_type} does not support conditional text generation")
else:
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
if self.model.config.model_type == "git" and prompt is None:
model_inputs["input_ids"] = None
return model_inputs
def _forward(self, model_inputs, generate_kwargs=None):