Add Image To Text Generation pipeline (#18821)
* Add Image2TextGenerationPipeline to supported pipelines * Add Flax and Tensorflow support * Add Flax and Tensorflow small tests * Add default model for Tensorflow * Add docstring * Fix doc style * Add tiny models for pytorch and flax * Remove flax from pipeline. Fix tests * Use ydshieh/vit-gpt2-coco-en as a default for both PyTorch and Tensorflow * Fix Tensorflow support Co-authored-by: Olivier Dehaene <olivier@huggingface.co>
This commit is contained in:
@@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about:
|
|||||||
- [`FillMaskPipeline`]
|
- [`FillMaskPipeline`]
|
||||||
- [`ImageClassificationPipeline`]
|
- [`ImageClassificationPipeline`]
|
||||||
- [`ImageSegmentationPipeline`]
|
- [`ImageSegmentationPipeline`]
|
||||||
|
- [`Image2TextGenerationPipeline`]
|
||||||
- [`ObjectDetectionPipeline`]
|
- [`ObjectDetectionPipeline`]
|
||||||
- [`QuestionAnsweringPipeline`]
|
- [`QuestionAnsweringPipeline`]
|
||||||
- [`SummarizationPipeline`]
|
- [`SummarizationPipeline`]
|
||||||
@@ -365,6 +366,12 @@ That should enable you to do all the custom code you want.
|
|||||||
- __call__
|
- __call__
|
||||||
- all
|
- all
|
||||||
|
|
||||||
|
### Image2TextGenerationPipeline
|
||||||
|
|
||||||
|
[[autodoc]] Image2TextGenerationPipeline
|
||||||
|
- __call__
|
||||||
|
- all
|
||||||
|
|
||||||
### NerPipeline
|
### NerPipeline
|
||||||
|
|
||||||
[[autodoc]] NerPipeline
|
[[autodoc]] NerPipeline
|
||||||
|
|||||||
@@ -384,6 +384,7 @@ _import_structure = {
|
|||||||
"CsvPipelineDataFormat",
|
"CsvPipelineDataFormat",
|
||||||
"FeatureExtractionPipeline",
|
"FeatureExtractionPipeline",
|
||||||
"FillMaskPipeline",
|
"FillMaskPipeline",
|
||||||
|
"Image2TextGenerationPipeline",
|
||||||
"ImageClassificationPipeline",
|
"ImageClassificationPipeline",
|
||||||
"ImageSegmentationPipeline",
|
"ImageSegmentationPipeline",
|
||||||
"JsonPipelineDataFormat",
|
"JsonPipelineDataFormat",
|
||||||
@@ -3191,6 +3192,7 @@ if TYPE_CHECKING:
|
|||||||
CsvPipelineDataFormat,
|
CsvPipelineDataFormat,
|
||||||
FeatureExtractionPipeline,
|
FeatureExtractionPipeline,
|
||||||
FillMaskPipeline,
|
FillMaskPipeline,
|
||||||
|
Image2TextGenerationPipeline,
|
||||||
ImageClassificationPipeline,
|
ImageClassificationPipeline,
|
||||||
ImageSegmentationPipeline,
|
ImageSegmentationPipeline,
|
||||||
JsonPipelineDataFormat,
|
JsonPipelineDataFormat,
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ from .base import (
|
|||||||
from .conversational import Conversation, ConversationalPipeline
|
from .conversational import Conversation, ConversationalPipeline
|
||||||
from .feature_extraction import FeatureExtractionPipeline
|
from .feature_extraction import FeatureExtractionPipeline
|
||||||
from .fill_mask import FillMaskPipeline
|
from .fill_mask import FillMaskPipeline
|
||||||
|
from .image2text_generation import Image2TextGenerationPipeline
|
||||||
from .image_classification import ImageClassificationPipeline
|
from .image_classification import ImageClassificationPipeline
|
||||||
from .image_segmentation import ImageSegmentationPipeline
|
from .image_segmentation import ImageSegmentationPipeline
|
||||||
from .object_detection import ObjectDetectionPipeline
|
from .object_detection import ObjectDetectionPipeline
|
||||||
@@ -90,6 +91,7 @@ if is_tf_available():
|
|||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFAutoModelForTableQuestionAnswering,
|
TFAutoModelForTableQuestionAnswering,
|
||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
|
TFAutoModelForVision2Seq,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -118,6 +120,7 @@ if is_torch_available():
|
|||||||
AutoModelForSpeechSeq2Seq,
|
AutoModelForSpeechSeq2Seq,
|
||||||
AutoModelForTableQuestionAnswering,
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
|
AutoModelForVision2Seq,
|
||||||
AutoModelForVisualQuestionAnswering,
|
AutoModelForVisualQuestionAnswering,
|
||||||
)
|
)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -302,6 +305,18 @@ SUPPORTED_TASKS = {
|
|||||||
"default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
|
"default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
|
||||||
"type": "image",
|
"type": "image",
|
||||||
},
|
},
|
||||||
|
"image2text-generation": {
|
||||||
|
"impl": Image2TextGenerationPipeline,
|
||||||
|
"tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
|
||||||
|
"pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
|
||||||
|
"default": {
|
||||||
|
"model": {
|
||||||
|
"pt": ("ydshieh/vit-gpt2-coco-en", "65636df"),
|
||||||
|
"tf": ("ydshieh/vit-gpt2-coco-en", "65636df"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "multimodal",
|
||||||
|
},
|
||||||
"object-detection": {
|
"object-detection": {
|
||||||
"impl": ObjectDetectionPipeline,
|
"impl": ObjectDetectionPipeline,
|
||||||
"tf": (),
|
"tf": (),
|
||||||
@@ -317,7 +332,7 @@ NO_TOKENIZER_TASKS = set()
|
|||||||
# any tokenizer/feature_extractor might be use for a given model so we cannot
|
# any tokenizer/feature_extractor might be use for a given model so we cannot
|
||||||
# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
|
# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
|
||||||
# see if the model defines such objects or not.
|
# see if the model defines such objects or not.
|
||||||
MULTI_MODEL_CONFIGS = {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}
|
MULTI_MODEL_CONFIGS = {"SpeechEncoderDecoderConfig", "VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
|
||||||
for task, values in SUPPORTED_TASKS.items():
|
for task, values in SUPPORTED_TASKS.items():
|
||||||
if values["type"] == "text":
|
if values["type"] == "text":
|
||||||
NO_FEATURE_EXTRACTOR_TASKS.add(task)
|
NO_FEATURE_EXTRACTOR_TASKS.add(task)
|
||||||
|
|||||||
96
src/transformers/pipelines/image2text_generation.py
Normal file
96
src/transformers/pipelines/image2text_generation.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from ..utils import (
|
||||||
|
add_end_docstrings,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
logging,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..image_utils import load_image
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||||
|
class Image2TextGenerationPipeline(Pipeline):
|
||||||
|
"""
|
||||||
|
Image2Text Generation pipeline using a `AutoModelForVision2Seq`. This pipeline predicts a caption for a given
|
||||||
|
image.
|
||||||
|
|
||||||
|
This image to text generation pipeline can currently be loaded from pipeline() using the following task identifier:
|
||||||
|
"image2text-generation".
|
||||||
|
|
||||||
|
See the list of available models on
|
||||||
|
[huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-to-text).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
requires_backends(self, "vision")
|
||||||
|
self.check_model_type(
|
||||||
|
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sanitize_parameters(self, **kwargs):
|
||||||
|
return {}, {}, {}
|
||||||
|
|
||||||
|
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
|
||||||
|
"""
|
||||||
|
Assign labels to the image(s) passed as inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||||
|
The pipeline handles three types of images:
|
||||||
|
|
||||||
|
- A string containing a HTTP(s) link pointing to an image
|
||||||
|
- A string containing a local path to an image
|
||||||
|
- An image loaded in PIL directly
|
||||||
|
|
||||||
|
The pipeline accepts either a single image or a batch of images.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
|
||||||
|
|
||||||
|
- **generated_text** (`str`) -- The generated text.
|
||||||
|
"""
|
||||||
|
return super().__call__(images, **kwargs)
|
||||||
|
|
||||||
|
def preprocess(self, image):
|
||||||
|
image = load_image(image)
|
||||||
|
model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def _forward(self, model_inputs):
|
||||||
|
# FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py`
|
||||||
|
# parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
|
||||||
|
# the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
|
||||||
|
# in the `_prepare_model_inputs` method.
|
||||||
|
inputs = model_inputs.pop(self.model.main_input_name)
|
||||||
|
model_outputs = self.model.generate(inputs, **model_inputs)
|
||||||
|
return model_outputs
|
||||||
|
|
||||||
|
def postprocess(self, model_outputs):
|
||||||
|
records = []
|
||||||
|
for output_ids in model_outputs:
|
||||||
|
record = {
|
||||||
|
"generated_text": self.tokenizer.decode(
|
||||||
|
output_ids,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
records.append(record)
|
||||||
|
return records
|
||||||
171
tests/pipelines/test_pipelines_image2text_generation.py
Normal file
171
tests/pipelines/test_pipelines_image2text_generation.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
|
||||||
|
from transformers.pipelines import pipeline
|
||||||
|
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_vision, slow
|
||||||
|
|
||||||
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
else:
|
||||||
|
|
||||||
|
class Image:
|
||||||
|
@staticmethod
|
||||||
|
def open(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@is_pipeline_test
|
||||||
|
@require_vision
|
||||||
|
class Image2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||||
|
model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||||
|
tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||||
|
|
||||||
|
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||||
|
pipe = pipeline("image2text-generation", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
examples = [
|
||||||
|
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
||||||
|
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
|
]
|
||||||
|
return pipe, examples
|
||||||
|
|
||||||
|
def run_pipeline_test(self, pipe, examples):
|
||||||
|
outputs = pipe(examples)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[{"generated_text": ANY(str)}],
|
||||||
|
[{"generated_text": ANY(str)}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_small_model_tf(self):
|
||||||
|
pipe = pipeline("image2text-generation", model="hf-internal-testing/tiny-random-vit-gpt2")
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
|
||||||
|
outputs = pipe(image)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": (
|
||||||
|
" intermedi intermedi intermedi intermedi intermedi "
|
||||||
|
"explorer explorer explorer explorer explorer explorer "
|
||||||
|
"explorer medicine medicine medicine medicine medicine "
|
||||||
|
"medicine medicine"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = pipe([image, image])
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": (
|
||||||
|
" intermedi intermedi intermedi intermedi intermedi "
|
||||||
|
"explorer explorer explorer explorer explorer explorer "
|
||||||
|
"explorer medicine medicine medicine medicine medicine "
|
||||||
|
"medicine medicine"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": (
|
||||||
|
" intermedi intermedi intermedi intermedi intermedi "
|
||||||
|
"explorer explorer explorer explorer explorer explorer "
|
||||||
|
"explorer medicine medicine medicine medicine medicine "
|
||||||
|
"medicine medicine"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt(self):
|
||||||
|
pipe = pipeline("image2text-generation", model="hf-internal-testing/tiny-random-vit-gpt2")
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
|
||||||
|
outputs = pipe(image)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": "growthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthGOGO"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = pipe([image, image])
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": "growthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthGOGO"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"generated_text": "growthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthgrowthGOGO"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_large_model_pt(self):
|
||||||
|
pipe = pipeline("image2text-generation", model="ydshieh/vit-gpt2-coco-en")
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
|
||||||
|
outputs = pipe(image)
|
||||||
|
self.assertEqual(outputs, [{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}])
|
||||||
|
|
||||||
|
outputs = pipe([image, image])
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||||
|
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_tf
|
||||||
|
def test_large_model_tf(self):
|
||||||
|
pipe = pipeline("image2text-generation", model="ydshieh/vit-gpt2-coco-en")
|
||||||
|
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||||
|
|
||||||
|
outputs = pipe(image)
|
||||||
|
self.assertEqual(outputs, [{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}])
|
||||||
|
|
||||||
|
outputs = pipe([image, image])
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||||
|
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
|
||||||
|
],
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user