Add Image To Text Generation pipeline (#18821)
* Add Image2TextGenerationPipeline to supported pipelines * Add Flax and Tensorflow support * Add Flax and Tensorflow small tests * Add default model for Tensorflow * Add docstring * Fix doc style * Add tiny models for pytorch and flax * Remove flax from pipeline. Fix tests * Use ydshieh/vit-gpt2-coco-en as a default for both PyTorch and Tensorflow * Fix Tensorflow support Co-authored-by: Olivier Dehaene <olivier@huggingface.co>
This commit is contained in:
@@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about:
|
||||
- [`FillMaskPipeline`]
|
||||
- [`ImageClassificationPipeline`]
|
||||
- [`ImageSegmentationPipeline`]
|
||||
- [`Image2TextGenerationPipeline`]
|
||||
- [`ObjectDetectionPipeline`]
|
||||
- [`QuestionAnsweringPipeline`]
|
||||
- [`SummarizationPipeline`]
|
||||
@@ -365,6 +366,12 @@ That should enable you to do all the custom code you want.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### Image2TextGenerationPipeline
|
||||
|
||||
[[autodoc]] Image2TextGenerationPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### NerPipeline
|
||||
|
||||
[[autodoc]] NerPipeline
|
||||
|
||||
@@ -384,6 +384,7 @@ _import_structure = {
|
||||
"CsvPipelineDataFormat",
|
||||
"FeatureExtractionPipeline",
|
||||
"FillMaskPipeline",
|
||||
"Image2TextGenerationPipeline",
|
||||
"ImageClassificationPipeline",
|
||||
"ImageSegmentationPipeline",
|
||||
"JsonPipelineDataFormat",
|
||||
@@ -3191,6 +3192,7 @@ if TYPE_CHECKING:
|
||||
CsvPipelineDataFormat,
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
Image2TextGenerationPipeline,
|
||||
ImageClassificationPipeline,
|
||||
ImageSegmentationPipeline,
|
||||
JsonPipelineDataFormat,
|
||||
|
||||
@@ -53,6 +53,7 @@ from .base import (
|
||||
from .conversational import Conversation, ConversationalPipeline
|
||||
from .feature_extraction import FeatureExtractionPipeline
|
||||
from .fill_mask import FillMaskPipeline
|
||||
from .image2text_generation import Image2TextGenerationPipeline
|
||||
from .image_classification import ImageClassificationPipeline
|
||||
from .image_segmentation import ImageSegmentationPipeline
|
||||
from .object_detection import ObjectDetectionPipeline
|
||||
@@ -90,6 +91,7 @@ if is_tf_available():
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
@@ -118,6 +120,7 @@ if is_torch_available():
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForVision2Seq,
|
||||
AutoModelForVisualQuestionAnswering,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
@@ -302,6 +305,18 @@ SUPPORTED_TASKS = {
|
||||
"default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
|
||||
"type": "image",
|
||||
},
|
||||
"image2text-generation": {
|
||||
"impl": Image2TextGenerationPipeline,
|
||||
"tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
|
||||
"pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": ("ydshieh/vit-gpt2-coco-en", "65636df"),
|
||||
"tf": ("ydshieh/vit-gpt2-coco-en", "65636df"),
|
||||
}
|
||||
},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"object-detection": {
|
||||
"impl": ObjectDetectionPipeline,
|
||||
"tf": (),
|
||||
@@ -317,7 +332,7 @@ NO_TOKENIZER_TASKS = set()
|
||||
# any tokenizer/feature_extractor might be use for a given model so we cannot
|
||||
# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
|
||||
# see if the model defines such objects or not.
|
||||
MULTI_MODEL_CONFIGS = {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}
|
||||
MULTI_MODEL_CONFIGS = {"SpeechEncoderDecoderConfig", "VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
|
||||
for task, values in SUPPORTED_TASKS.items():
|
||||
if values["type"] == "text":
|
||||
NO_FEATURE_EXTRACTOR_TASKS.add(task)
|
||||
|
||||
96
src/transformers/pipelines/image2text_generation.py
Normal file
96
src/transformers/pipelines/image2text_generation.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import List, Union
|
||||
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_tf_available():
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class Image2TextGenerationPipeline(Pipeline):
|
||||
"""
|
||||
Image2Text Generation pipeline using a `AutoModelForVision2Seq`. This pipeline predicts a caption for a given
|
||||
image.
|
||||
|
||||
This image to text generation pipeline can currently be loaded from pipeline() using the following task identifier:
|
||||
"image2text-generation".
|
||||
|
||||
See the list of available models on
|
||||
[huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-to-text).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
return {}, {}, {}
|
||||
|
||||
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
|
||||
"""
|
||||
Assign labels to the image(s) passed as inputs.
|
||||
|
||||
Args:
|
||||
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||
The pipeline handles three types of images:
|
||||
|
||||
- A string containing a HTTP(s) link pointing to an image
|
||||
- A string containing a local path to an image
|
||||
- An image loaded in PIL directly
|
||||
|
||||
The pipeline accepts either a single image or a batch of images.
|
||||
|
||||
Return:
|
||||
A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
|
||||
|
||||
- **generated_text** (`str`) -- The generated text.
|
||||
"""
|
||||
return super().__call__(images, **kwargs)
|
||||
|
||||
def preprocess(self, image):
|
||||
image = load_image(image)
|
||||
model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
# FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py`
|
||||
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
|
||||
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
|
||||
# in the `_prepare_model_inputs` method.
|
||||
inputs = model_inputs.pop(self.model.main_input_name)
|
||||
model_outputs = self.model.generate(inputs, **model_inputs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
records = []
|
||||
for output_ids in model_outputs:
|
||||
record = {
|
||||
"generated_text": self.tokenizer.decode(
|
||||
output_ids,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
}
|
||||
records.append(record)
|
||||
return records
|
||||
171
tests/pipelines/test_pipelines_image2text_generation.py
Normal file
171
tests/pipelines/test_pipelines_image2text_generation.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
|
||||
from transformers.pipelines import pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_vision, slow
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
else:
|
||||
|
||||
class Image:
|
||||
@staticmethod
|
||||
def open(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@require_vision
|
||||
class Image2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||
pipe = pipeline("image2text-generation", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
examples = [
|
||||
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
]
|
||||
return pipe, examples
|
||||
|
||||
def run_pipeline_test(self, pipe, examples):
|
||||
outputs = pipe(examples)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": ANY(str)}],
|
||||
[{"generated_text": ANY(str)}],
|
||||
],
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
pipe = pipeline("image2text-generation", model="hf-internal-testing/tiny-random-vit-gpt2")
|
||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||
|
||||
outputs = pipe(image)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
" intermedi intermedi intermedi intermedi intermedi "
|
||||
"explorer explorer explorer explorer explorer explorer "
|
||||
"explorer medicine medicine medicine medicine medicine "
|
||||
"medicine medicine"
|
||||
)
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
outputs = pipe([image, image])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
" intermedi intermedi intermedi intermedi intermedi "
|
||||
"explorer explorer explorer explorer explorer explorer "
|
||||
"explorer medicine medicine medicine medicine medicine "
|
||||
"medicine medicine"
|
||||
)
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"generated_text": (
|
||||
" intermedi intermedi intermedi intermedi intermedi "
|
||||
"explorer explorer explorer explorer explorer explorer "
|
||||
"explorer medicine medicine medicine medicine medicine "
|
||||
"medicine medicine"
|
||||
)
|
||||
},
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
pipe = pipeline("image2text-generation", model="hf-internal-testing/tiny-random-vit-gpt2")
|
||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||
|
||||
outputs = pipe(image)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{
|
||||
"generated_text": "growthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthGOGO"
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
outputs = pipe([image, image])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[
|
||||
{
|
||||
"generated_text": "growthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthGOGO"
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"generated_text": "growthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthGOGO"
|
||||
}
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_large_model_pt(self):
|
||||
pipe = pipeline("image2text-generation", model="ydshieh/vit-gpt2-coco-en")
|
||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||
|
||||
outputs = pipe(image)
|
||||
self.assertEqual(outputs, [{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}])
|
||||
|
||||
outputs = pipe([image, image])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_large_model_tf(self):
|
||||
pipe = pipeline("image2text-generation", model="ydshieh/vit-gpt2-coco-en")
|
||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||
|
||||
outputs = pipe(image)
|
||||
self.assertEqual(outputs, [{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}])
|
||||
|
||||
outputs = pipe([image, image])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||
],
|
||||
)
|
||||
Reference in New Issue
Block a user