Enabling TF on image-classification pipeline. (#15030)

This commit is contained in:
Nicolas Patry
2022-01-06 14:16:00 +01:00
committed by GitHub
parent 9f89fa02ed
commit 5a06118b39
2 changed files with 70 additions and 13 deletions

View File

@@ -14,7 +14,12 @@
import unittest
from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
PreTrainedTokenizer,
is_vision_available,
)
from transformers.pipelines import ImageClassificationPipeline, pipeline
from transformers.testing_utils import (
is_pipeline_test,
@@ -40,9 +45,9 @@ else:
@is_pipeline_test
@require_vision
@require_torch
class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
def get_test_pipeline(self, model, tokenizer, feature_extractor):
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
@@ -145,9 +150,42 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
)
@require_tf
@unittest.skip("Image classification is not implemented for TF")
def test_small_model_tf(self):
pass
small_model = "lysandre/tiny-vit-random"
image_classifier = pipeline("image-classification", model=small_model)
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
{"score": 0.0014, "label": "trench coat"},
{"score": 0.0014, "label": "handkerchief, hankie, hanky, hankey"},
{"score": 0.0014, "label": "baboon"},
],
)
outputs = image_classifier(
[
"http://images.cocodataset.org/val2017/000000039769.jpg",
"http://images.cocodataset.org/val2017/000000039769.jpg",
],
top_k=2,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
[
{"score": 0.0015, "label": "chambered nautilus, pearly nautilus, nautilus"},
{"score": 0.0015, "label": "pajama, pyjama, pj's, jammies"},
],
],
)
def test_custom_tokenizer(self):
tokenizer = PreTrainedTokenizer()