Add Visual Question Answering (VQA) pipeline (#17286)
* wip * rebase * all tests pass * rebase * ready for PR * address comments * fix styles * add require_torch to pipeline test * remove remote image to improve CI consistency * address comments; fix tf/flax tests * address comments; fix tf/flax tests * fix tests; add alias * repo consistency tests * Update src/transformers/pipelines/visual_question_answering.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * address comments * Update src/transformers/pipelines/visual_question_answering.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * merge * Update src/transformers/models/auto/modeling_auto.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * merge Co-authored-by: Sijun He <sijunhe@Sijuns-MacBook-Pro.local> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -38,6 +38,7 @@ There are two categories of pipeline abstractions to be aware about:
|
|||||||
- [`Text2TextGenerationPipeline`]
|
- [`Text2TextGenerationPipeline`]
|
||||||
- [`TokenClassificationPipeline`]
|
- [`TokenClassificationPipeline`]
|
||||||
- [`TranslationPipeline`]
|
- [`TranslationPipeline`]
|
||||||
|
- [`VisualQuestionAnsweringPipeline`]
|
||||||
- [`ZeroShotClassificationPipeline`]
|
- [`ZeroShotClassificationPipeline`]
|
||||||
- [`ZeroShotImageClassificationPipeline`]
|
- [`ZeroShotImageClassificationPipeline`]
|
||||||
|
|
||||||
@@ -423,6 +424,12 @@ See [`TokenClassificationPipeline`] for all details.
|
|||||||
- __call__
|
- __call__
|
||||||
- all
|
- all
|
||||||
|
|
||||||
|
### VisualQuestionAnsweringPipeline
|
||||||
|
|
||||||
|
[[autodoc]] VisualQuestionAnsweringPipeline
|
||||||
|
- __call__
|
||||||
|
- all
|
||||||
|
|
||||||
### ZeroShotClassificationPipeline
|
### ZeroShotClassificationPipeline
|
||||||
|
|
||||||
[[autodoc]] ZeroShotClassificationPipeline
|
[[autodoc]] ZeroShotClassificationPipeline
|
||||||
|
|||||||
@@ -122,6 +122,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
|||||||
|
|
||||||
[[autodoc]] AutoModelForVision2Seq
|
[[autodoc]] AutoModelForVision2Seq
|
||||||
|
|
||||||
|
## AutoModelForVisualQuestionAnswering
|
||||||
|
|
||||||
|
[[autodoc]] AutoModelForVisualQuestionAnswering
|
||||||
|
|
||||||
## AutoModelForAudioClassification
|
## AutoModelForAudioClassification
|
||||||
|
|
||||||
[[autodoc]] AutoModelForAudioClassification
|
[[autodoc]] AutoModelForAudioClassification
|
||||||
|
|||||||
@@ -377,6 +377,7 @@ _import_structure = {
|
|||||||
"TextGenerationPipeline",
|
"TextGenerationPipeline",
|
||||||
"TokenClassificationPipeline",
|
"TokenClassificationPipeline",
|
||||||
"TranslationPipeline",
|
"TranslationPipeline",
|
||||||
|
"VisualQuestionAnsweringPipeline",
|
||||||
"ZeroShotClassificationPipeline",
|
"ZeroShotClassificationPipeline",
|
||||||
"ZeroShotImageClassificationPipeline",
|
"ZeroShotImageClassificationPipeline",
|
||||||
"pipeline",
|
"pipeline",
|
||||||
@@ -758,6 +759,7 @@ else:
|
|||||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
"MODEL_FOR_VISION_2_SEQ_MAPPING",
|
"MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||||
|
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
@@ -783,6 +785,7 @@ else:
|
|||||||
"AutoModelForTableQuestionAnswering",
|
"AutoModelForTableQuestionAnswering",
|
||||||
"AutoModelForTokenClassification",
|
"AutoModelForTokenClassification",
|
||||||
"AutoModelForVision2Seq",
|
"AutoModelForVision2Seq",
|
||||||
|
"AutoModelForVisualQuestionAnswering",
|
||||||
"AutoModelWithLMHead",
|
"AutoModelWithLMHead",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -2961,6 +2964,7 @@ if TYPE_CHECKING:
|
|||||||
TextGenerationPipeline,
|
TextGenerationPipeline,
|
||||||
TokenClassificationPipeline,
|
TokenClassificationPipeline,
|
||||||
TranslationPipeline,
|
TranslationPipeline,
|
||||||
|
VisualQuestionAnsweringPipeline,
|
||||||
ZeroShotClassificationPipeline,
|
ZeroShotClassificationPipeline,
|
||||||
ZeroShotImageClassificationPipeline,
|
ZeroShotImageClassificationPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
@@ -3291,6 +3295,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
@@ -3316,6 +3321,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForTableQuestionAnswering,
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
|
AutoModelForVisualQuestionAnswering,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
from .models.bart import (
|
from .models.bart import (
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ else:
|
|||||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||||
"MODEL_FOR_VISION_2_SEQ_MAPPING",
|
"MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||||
|
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
@@ -89,6 +90,7 @@ else:
|
|||||||
"AutoModelForTableQuestionAnswering",
|
"AutoModelForTableQuestionAnswering",
|
||||||
"AutoModelForTokenClassification",
|
"AutoModelForTokenClassification",
|
||||||
"AutoModelForVision2Seq",
|
"AutoModelForVision2Seq",
|
||||||
|
"AutoModelForVisualQuestionAnswering",
|
||||||
"AutoModelWithLMHead",
|
"AutoModelWithLMHead",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -202,6 +204,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||||
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_MAPPING,
|
MODEL_MAPPING,
|
||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
@@ -227,6 +230,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForTableQuestionAnswering,
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
|
AutoModelForVisualQuestionAnswering,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("speech_to_text", "Speech2TextFeatureExtractor"),
|
("speech_to_text", "Speech2TextFeatureExtractor"),
|
||||||
("swin", "ViTFeatureExtractor"),
|
("swin", "ViTFeatureExtractor"),
|
||||||
("van", "ConvNextFeatureExtractor"),
|
("van", "ConvNextFeatureExtractor"),
|
||||||
|
("vilt", "ViltFeatureExtractor"),
|
||||||
("vit", "ViTFeatureExtractor"),
|
("vit", "ViTFeatureExtractor"),
|
||||||
("vit_mae", "ViTFeatureExtractor"),
|
("vit_mae", "ViTFeatureExtractor"),
|
||||||
("wav2vec2", "Wav2Vec2FeatureExtractor"),
|
("wav2vec2", "Wav2Vec2FeatureExtractor"),
|
||||||
|
|||||||
@@ -548,6 +548,12 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
("vilt", "ViltForQuestionAnswering"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Token Classification mapping
|
# Model for Token Classification mapping
|
||||||
@@ -706,6 +712,9 @@ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
|||||||
CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
|
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
|
||||||
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
|
||||||
|
)
|
||||||
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
||||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
|
||||||
@@ -813,6 +822,17 @@ AutoModelForTableQuestionAnswering = auto_class_update(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForVisualQuestionAnswering = auto_class_update(
|
||||||
|
AutoModelForVisualQuestionAnswering,
|
||||||
|
head_doc="visual question answering",
|
||||||
|
checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForTokenClassification(_BaseAutoModelClass):
|
class AutoModelForTokenClassification(_BaseAutoModelClass):
|
||||||
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -229,6 +229,7 @@ else:
|
|||||||
("tapas", ("TapasTokenizer", None)),
|
("tapas", ("TapasTokenizer", None)),
|
||||||
("tapex", ("TapexTokenizer", None)),
|
("tapex", ("TapexTokenizer", None)),
|
||||||
("transfo-xl", ("TransfoXLTokenizer", None)),
|
("transfo-xl", ("TransfoXLTokenizer", None)),
|
||||||
|
("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||||
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
|
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
|
||||||
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
|
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ from .token_classification import (
|
|||||||
TokenClassificationArgumentHandler,
|
TokenClassificationArgumentHandler,
|
||||||
TokenClassificationPipeline,
|
TokenClassificationPipeline,
|
||||||
)
|
)
|
||||||
|
from .visual_question_answering import VisualQuestionAnsweringPipeline
|
||||||
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
|
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
|
||||||
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
|
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
|
||||||
|
|
||||||
@@ -94,6 +95,7 @@ if is_torch_available():
|
|||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForAudioClassification,
|
AutoModelForAudioClassification,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -109,6 +111,7 @@ if is_torch_available():
|
|||||||
AutoModelForSpeechSeq2Seq,
|
AutoModelForSpeechSeq2Seq,
|
||||||
AutoModelForTableQuestionAnswering,
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
|
AutoModelForVisualQuestionAnswering,
|
||||||
)
|
)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..modeling_tf_utils import TFPreTrainedModel
|
from ..modeling_tf_utils import TFPreTrainedModel
|
||||||
@@ -121,6 +124,7 @@ logger = logging.get_logger(__name__)
|
|||||||
TASK_ALIASES = {
|
TASK_ALIASES = {
|
||||||
"sentiment-analysis": "text-classification",
|
"sentiment-analysis": "text-classification",
|
||||||
"ner": "token-classification",
|
"ner": "token-classification",
|
||||||
|
"vqa": "visual-question-answering",
|
||||||
}
|
}
|
||||||
SUPPORTED_TASKS = {
|
SUPPORTED_TASKS = {
|
||||||
"audio-classification": {
|
"audio-classification": {
|
||||||
@@ -190,6 +194,19 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"type": "text",
|
"type": "text",
|
||||||
},
|
},
|
||||||
|
"visual-question-answering": {
|
||||||
|
"impl": VisualQuestionAnsweringPipeline,
|
||||||
|
"pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
|
||||||
|
"tf": (),
|
||||||
|
"default": {
|
||||||
|
"model": {
|
||||||
|
"pt": "dandelin/vilt-b32-finetuned-vqa",
|
||||||
|
"tokenizer": "dandelin/vilt-b32-finetuned-vqa",
|
||||||
|
"feature_extractor": "dandelin/vilt-b32-finetuned-vqa",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"type": "multimodal",
|
||||||
|
},
|
||||||
"fill-mask": {
|
"fill-mask": {
|
||||||
"impl": FillMaskPipeline,
|
"impl": FillMaskPipeline,
|
||||||
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
|
"tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
|
||||||
|
|||||||
115
src/transformers/pipelines/visual_question_answering.py
Normal file
115
src/transformers/pipelines/visual_question_answering.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging
|
||||||
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..image_utils import load_image
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
class VisualQuestionAnsweringPipeline(Pipeline):
    """
    Visual Question Answering pipeline using an `AutoModelForVisualQuestionAnswering`. This pipeline is currently only
    available in PyTorch.

    This visual question answering pipeline can currently be loaded from [`pipeline`] using the following task
    identifiers: `"visual-question-answering", "vqa"`.

    The models that this pipeline can use are models that have been fine-tuned on a visual question answering task. See
    the up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=visual-question-answering).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Check the loaded model's type against the visual question answering mapping.
        self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING)

    def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
        """
        Sort call-time keyword arguments into per-stage parameter dictionaries.

        `padding` and `truncation` are routed to `preprocess`, `top_k` to `postprocess`. Arguments left at `None`
        are omitted so that the defaults declared on those methods apply. The second returned dictionary is always
        empty; any other keyword arguments are ignored.
        """
        preprocess_params, postprocess_params = {}, {}
        if padding is not None:
            preprocess_params["padding"] = padding
        if truncation is not None:
            preprocess_params["truncation"] = truncation
        if top_k is not None:
            postprocess_params["top_k"] = top_k
        return preprocess_params, {}, postprocess_params

    def __call__(self, image: Union["Image.Image", str], question: str = None, **kwargs):
        r"""
        Answers open-ended questions about images. The pipeline accepts several types of inputs which are detailed
        below:

        - `pipeline(image=image, question=question)`
        - `pipeline({"image": image, "question": question})`
        - `pipeline([{"image": image, "question": question}])`
        - `pipeline([{"image": image, "question": question}, {"image": image, "question": question}])`

        Args:
            image (`str`, `PIL.Image`, `dict` or `List[dict]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                When `question` is not a string, this argument is forwarded unchanged as the pipeline input and is
                therefore expected to already be a `{"image": ..., "question": ...}` dictionary, a list of such
                dictionaries, or a generator/dataset yielding them.
            question (`str`, *optional*):
                The question asked about `image`. Only used when `image` is a single image (PIL image or string).
            top_k (`int`, *optional*, defaults to 5):
                The number of top labels that will be returned by the pipeline. If the provided number is higher than
                the number of labels available in the model configuration, it will default to the number of labels.
            padding (*optional*, defaults to `False`):
                Forwarded to the tokenizer when encoding the question.
            truncation (*optional*, defaults to `False`):
                Forwarded to the tokenizer when encoding the question.
        Return:
            A list of dictionaries (one per returned answer) for a single input, or a list of such lists when a list
            of inputs is given. The dictionaries contain the following keys:

            - **score** (`float`) -- The score attributed by the model for that answer.
            - **answer** (`str`) -- The answer identified by the model.
        """
        if isinstance(image, (Image.Image, str)) and isinstance(question, str):
            inputs = {"image": image, "question": question}
        else:
            # Supports the following formats, passed through the `image` argument:
            # - {"image": image, "question": question}
            # - [{"image": image, "question": question}]
            # - Generator and datasets
            inputs = image
        results = super().__call__(inputs, **kwargs)
        return results

    def preprocess(self, inputs, padding=False, truncation=False):
        """
        Turn one `{"image": ..., "question": ...}` example into model inputs.

        The question is encoded with the tokenizer and the loaded image with the feature extractor; the image
        features are merged into the tokenizer output, which is returned.
        """
        image = load_image(inputs["image"])
        model_inputs = self.tokenizer(
            inputs["question"], return_tensors=self.framework, padding=padding, truncation=truncation
        )
        image_features = self.feature_extractor(images=image, return_tensors=self.framework)
        model_inputs.update(image_features)
        return model_inputs

    def _forward(self, model_inputs):
        """Run the model on the preprocessed inputs and return its raw outputs."""
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, top_k=5):
        """
        Convert model outputs into the `top_k` highest-scoring answers.

        Returns a list of `{"score": float, "answer": str}` dictionaries, where answers are looked up in the model
        configuration's `id2label`.

        Raises:
            ValueError: If the pipeline framework is not PyTorch (`"pt"`).
        """
        # Never request more answers than the model has labels.
        top_k = min(top_k, self.model.config.num_labels)

        if self.framework == "pt":
            # Sigmoid scores each label independently; only the first item of the batch is read.
            probs = model_outputs.logits.sigmoid()[0]
            scores, ids = probs.topk(top_k)
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        scores = scores.tolist()
        ids = ids.tolist()
        return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|
||||||
@@ -409,6 +409,9 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
|||||||
MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
MODEL_FOR_VISION_2_SEQ_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
MODEL_MAPPING = None
|
MODEL_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -576,6 +579,13 @@ class AutoModelForVision2Seq(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoModelWithLMHead(metaclass=DummyObject):
|
class AutoModelWithLMHead(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
115
tests/pipelines/test_pipelines_visual_question_answering.py
Normal file
115
tests/pipelines/test_pipelines_visual_question_answering.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_available
|
||||||
|
from transformers.pipelines import pipeline
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
is_pipeline_test,
|
||||||
|
nested_simplify,
|
||||||
|
require_tf,
|
||||||
|
require_torch,
|
||||||
|
require_vision,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
else:
|
||||||
|
|
||||||
|
class Image:
|
||||||
|
@staticmethod
|
||||||
|
def open(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@is_pipeline_test
@require_torch
@require_vision
class VisualQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    """Tests for the `visual-question-answering` pipeline (PyTorch only)."""

    model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING

    def get_test_pipeline(self, model, tokenizer, feature_extractor):
        """Build a tiny VQA pipeline plus two examples: one with a PIL image, one with an image path."""
        image_path = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        question = "How many cats are there?"

        vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
        pil_example = {"image": Image.open(image_path), "question": question}
        path_example = {"image": image_path, "question": question}
        return vqa_pipeline, [pil_example, path_example]

    def run_pipeline_test(self, vqa_pipeline, examples):
        """Check that each example yields exactly one `{score, answer}` entry when `top_k=1`."""
        predictions = vqa_pipeline(examples, top_k=1)
        expected = [
            [{"score": ANY(float), "answer": ANY(str)}],
            [{"score": ANY(float), "answer": ANY(str)}],
        ]
        self.assertEqual(predictions, expected)

    @require_torch
    def test_small_model_pt(self):
        """Keyword and dictionary call styles both return two typed entries on the tiny random model."""
        vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        question = "How many cats are there?"
        expected = [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]

        # Keyword-argument call style.
        from_kwargs = vqa_pipeline(image=image, question=question, top_k=2)
        self.assertEqual(from_kwargs, expected)

        # Single-dictionary call style.
        from_dict = vqa_pipeline({"image": image, "question": question}, top_k=2)
        self.assertEqual(from_dict, expected)

    @slow
    @require_torch
    def test_large_model_pt(self):
        """The fine-tuned ViLT checkpoint gives the reference answers for every supported call style."""
        vqa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        question = "How many cats are there?"
        expected = [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]

        # Keyword-argument call style.
        from_kwargs = vqa_pipeline(image=image, question=question, top_k=2)
        self.assertEqual(nested_simplify(from_kwargs, decimals=4), expected)

        # Single-dictionary call style.
        from_dict = vqa_pipeline({"image": image, "question": question}, top_k=2)
        self.assertEqual(nested_simplify(from_dict, decimals=4), expected)

        # A list of dictionaries returns one result list per example.
        example = {"image": image, "question": question}
        from_list = vqa_pipeline([example, dict(example)], top_k=2)
        self.assertEqual(nested_simplify(from_list, decimals=4), [expected, expected])

    @require_tf
    @unittest.skip("Visual question answering not implemented in TF")
    def test_small_model_tf(self):
        pass
|
||||||
Reference in New Issue
Block a user