Enabling TF on image-classification pipeline. (#15030)
This commit is contained in:
@@ -14,7 +14,12 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
PreTrainedTokenizer,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.pipelines import ImageClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
@@ -40,9 +45,9 @@ else:
|
||||
|
||||
@is_pipeline_test
|
||||
@require_vision
|
||||
@require_torch
|
||||
class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
|
||||
@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
)
|
||||
|
||||
@require_tf
|
||||
@unittest.skip("Image classification is not implemented for TF")
|
||||
def test_small_model_tf(self):
|
||||
pass
|
||||
small_model = "lysandre/tiny-vit-random"
|
||||
image_classifier = pipeline("image-classification", model=small_model)
|
||||
|
||||
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||
{"score": 0.0014, "label": "trench coat"},
|
||||
{"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"},
|
||||
{"score": 0.0014, "label": "baboon"},
|
||||
],
|
||||
)
|
||||
|
||||
outputs = image_classifier(
|
||||
[
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
],
|
||||
top_k=2,
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[
|
||||
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||
],
|
||||
[
|
||||
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
|
||||
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
def test_custom_tokenizer(self):
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
|
||||
Reference in New Issue
Block a user