[image-to-text pipeline] Add conditional text support + GIT (#23362)
* First draft * Remove print statements * Add conditional generation * Add more tests * Remove scripts * Remove BLIP specific linkes * Add support for pix2struct * Add fast test * Address comment * Fix style
This commit is contained in:
@@ -20,6 +20,8 @@ if is_tf_available():
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -56,8 +58,13 @@ class ImageToTextPipeline(Pipeline):
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
|
||||
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
|
||||
forward_kwargs = {}
|
||||
preprocess_params = {}
|
||||
|
||||
if prompt is not None:
|
||||
preprocess_params["prompt"] = prompt
|
||||
|
||||
if generate_kwargs is not None:
|
||||
forward_kwargs["generate_kwargs"] = generate_kwargs
|
||||
if max_new_tokens is not None:
|
||||
@@ -69,7 +76,7 @@ class ImageToTextPipeline(Pipeline):
|
||||
" please use only one"
|
||||
)
|
||||
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
|
||||
return {}, forward_kwargs, {}
|
||||
return preprocess_params, forward_kwargs, {}
|
||||
|
||||
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
|
||||
"""
|
||||
@@ -98,9 +105,43 @@ class ImageToTextPipeline(Pipeline):
|
||||
"""
|
||||
return super().__call__(images, **kwargs)
|
||||
|
||||
def preprocess(self, image):
|
||||
def preprocess(self, image, prompt=None):
|
||||
image = load_image(image)
|
||||
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
|
||||
|
||||
if prompt is not None:
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError(
|
||||
f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
|
||||
"Note also that one single text can be provided for conditional image to text generation."
|
||||
)
|
||||
|
||||
model_type = self.model.config.model_type
|
||||
|
||||
if model_type == "git":
|
||||
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
|
||||
input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
|
||||
input_ids = [self.tokenizer.cls_token_id] + input_ids
|
||||
input_ids = torch.tensor(input_ids).unsqueeze(0)
|
||||
model_inputs.update({"input_ids": input_ids})
|
||||
|
||||
elif model_type == "pix2struct":
|
||||
model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
|
||||
|
||||
elif model_type != "vision-encoder-decoder":
|
||||
# vision-encoder-decoder does not support conditional generation
|
||||
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
|
||||
text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
|
||||
model_inputs.update(text_inputs)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Model type {model_type} does not support conditional text generation")
|
||||
|
||||
else:
|
||||
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
|
||||
|
||||
if self.model.config.model_type == "git" and prompt is None:
|
||||
model_inputs["input_ids"] = None
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs, generate_kwargs=None):
|
||||
|
||||
Reference in New Issue
Block a user