Enabling TF on image-classification pipeline. (#15030)
This commit is contained in:
@@ -1,6 +1,12 @@
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
from ..file_utils import (
|
||||||
|
add_end_docstrings,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
@@ -10,6 +16,11 @@ if is_vision_available():
|
|||||||
|
|
||||||
from ..image_utils import load_image
|
from ..image_utils import load_image
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
@@ -31,12 +42,12 @@ class ImageClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if self.framework == "tf":
|
|
||||||
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
|
|
||||||
|
|
||||||
requires_backends(self, "vision")
|
requires_backends(self, "vision")
|
||||||
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
self.check_model_type(
|
||||||
|
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
if self.framework == "tf"
|
||||||
|
else MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
)
|
||||||
|
|
||||||
def _sanitize_parameters(self, top_k=None):
|
def _sanitize_parameters(self, top_k=None):
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
@@ -77,7 +88,7 @@ class ImageClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
def preprocess(self, image):
|
def preprocess(self, image):
|
||||||
image = load_image(image)
|
image = load_image(image)
|
||||||
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
|
model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def _forward(self, model_inputs):
|
def _forward(self, model_inputs):
|
||||||
@@ -87,8 +98,16 @@ class ImageClassificationPipeline(Pipeline):
|
|||||||
def postprocess(self, model_outputs, top_k=5):
|
def postprocess(self, model_outputs, top_k=5):
|
||||||
if top_k > self.model.config.num_labels:
|
if top_k > self.model.config.num_labels:
|
||||||
top_k = self.model.config.num_labels
|
top_k = self.model.config.num_labels
|
||||||
|
|
||||||
|
if self.framework == "pt":
|
||||||
probs = model_outputs.logits.softmax(-1)[0]
|
probs = model_outputs.logits.softmax(-1)[0]
|
||||||
scores, ids = probs.topk(top_k)
|
scores, ids = probs.topk(top_k)
|
||||||
|
elif self.framework == "tf":
|
||||||
|
probs = tf.nn.softmax(model_outputs.logits, axis=-1)[0]
|
||||||
|
topk = tf.math.top_k(probs, k=top_k)
|
||||||
|
scores, ids = topk.values.numpy(), topk.indices.numpy()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported framework: {self.framework}")
|
||||||
|
|
||||||
scores = scores.tolist()
|
scores = scores.tolist()
|
||||||
ids = ids.tolist()
|
ids = ids.tolist()
|
||||||
|
|||||||
@@ -14,7 +14,12 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available
|
from transformers import (
|
||||||
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
is_vision_available,
|
||||||
|
)
|
||||||
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
@@ -40,9 +45,9 @@ else:
|
|||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
|
||||||
class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||||
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||||
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
|
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
|
||||||
@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
)
|
)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@unittest.skip("Image classification is not implemented for TF")
|
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
pass
|
small_model = "lysandre/tiny-vit-random"
|
||||||
|
image_classifier = pipeline("image-classification", model=small_model)
|
||||||
|
|
||||||
|
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||||
|
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||||
|
{"score": 0.0014, "label": "trench coat"},
|
||||||
|
{"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"},
|
||||||
|
{"score": 0.0014, "label": "baboon"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = image_classifier(
|
||||||
|
[
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||||
|
],
|
||||||
|
top_k=2,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=4),
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||||
|
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||||
|
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def test_custom_tokenizer(self):
|
def test_custom_tokenizer(self):
|
||||||
tokenizer = PreTrainedTokenizer()
|
tokenizer = PreTrainedTokenizer()
|
||||||
|
|||||||
Reference in New Issue
Block a user