Enabling TF on image-classification pipeline. (#15030)
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
from typing import List, Union
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||
from ..file_utils import (
|
||||
add_end_docstrings,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
requires_backends,
|
||||
)
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
@@ -10,6 +16,11 @@ if is_vision_available():
|
||||
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
@@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.framework == "tf":
|
||||
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
|
||||
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
||||
self.check_model_type(
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
if self.framework == "tf"
|
||||
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, top_k=None):
|
||||
postprocess_params = {}
|
||||
@@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline):
|
||||
|
||||
def preprocess(self, image):
|
||||
image = load_image(image)
|
||||
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
|
||||
model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
@@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline):
|
||||
def postprocess(self, model_outputs, top_k=5):
|
||||
if top_k > self.model.config.num_labels:
|
||||
top_k = self.model.config.num_labels
|
||||
|
||||
if self.framework == "pt":
|
||||
probs = model_outputs.logits.softmax(-1)[0]
|
||||
scores, ids = probs.topk(top_k)
|
||||
elif self.framework == "tf":
|
||||
probs = tf.nn.softmax(model_outputs.logits, axis=-1)[0]
|
||||
topk = tf.math.top_k(probs, k=top_k)
|
||||
scores, ids = topk.values.numpy(), topk.indices.numpy()
|
||||
else:
|
||||
raise ValueError(f"Unsupported framework: {self.framework}")
|
||||
|
||||
scores = scores.tolist()
|
||||
ids = ids.tolist()
|
||||
|
||||
@@ -14,7 +14,12 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
PreTrainedTokenizer,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
@@ -40,9 +45,9 @@ else:
|
||||
|
||||
@is_pipeline_test
|
||||
@require_vision
|
||||
@require_torch
|
||||
class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
|
||||
@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
)
|
||||
|
||||
@require_tf
|
||||
@unittest.skip("Image classification is not implemented for TF")
|
||||
def test_small_model_tf(self):
|
||||
pass
|
||||
small_model = "lysandre/tiny-vit-random"
|
||||
image_classifier = pipeline("image-classification", model=small_model)
|
||||
|
||||
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||
{"score": 0.0014, "label": "trench coat"},
|
||||
{"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"},
|
||||
{"score": 0.0014, "label": "baboon"},
|
||||
],
|
||||
)
|
||||
|
||||
outputs = image_classifier(
|
||||
[
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
],
|
||||
top_k=2,
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[
|
||||
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||
],
|
||||
[
|
||||
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
def test_custom_tokenizer(self):
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
|
||||
Reference in New Issue
Block a user