Pipeline update & tests (#12207)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForImageClassification,
|
||||
PreTrainedTokenizer,
|
||||
@@ -128,3 +129,33 @@ class ImageClassificationPipelineTests(unittest.TestCase):
|
||||
image_classifier = pipeline("image-classification", model=self.small_models[0], tokenizer=tokenizer)
|
||||
|
||||
self.assertIs(image_classifier.tokenizer, tokenizer)
|
||||
|
||||
def test_num_labels_inferior_to_topk(self):
|
||||
for small_model in self.small_models:
|
||||
|
||||
num_labels = 2
|
||||
model = AutoModelForImageClassification.from_config(
|
||||
AutoConfig.from_pretrained(small_model, num_labels=num_labels)
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(small_model)
|
||||
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
|
||||
|
||||
for valid_input in self.valid_inputs:
|
||||
output = image_classifier(**valid_input)
|
||||
|
||||
def assert_valid_pipeline_output(pipeline_output):
|
||||
self.assertTrue(isinstance(pipeline_output, list))
|
||||
self.assertEqual(len(pipeline_output), num_labels)
|
||||
for label_result in pipeline_output:
|
||||
self.assertTrue(isinstance(label_result, dict))
|
||||
self.assertIn("label", label_result)
|
||||
self.assertIn("score", label_result)
|
||||
|
||||
if isinstance(valid_input["images"], list):
|
||||
# When images are batched, pipeline output is a list of lists of dictionaries
|
||||
self.assertEqual(len(valid_input["images"]), len(output))
|
||||
for individual_output in output:
|
||||
assert_valid_pipeline_output(individual_output)
|
||||
else:
|
||||
# When images are batched, pipeline output is a list of dictionaries
|
||||
assert_valid_pipeline_output(output)
|
||||
|
||||
Reference in New Issue
Block a user