Enable multi-label image classification in pipeline (#28433)
Enable multi-label image classification
This commit is contained in:
@@ -221,3 +221,49 @@ class ImageClassificationPipelineTests(unittest.TestCase):
|
||||
{"score": 0.0096, "label": "quilt, comforter, comfort, puff"},
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_multilabel_classification(self):
|
||||
small_model = "hf-internal-testing/tiny-random-vit"
|
||||
|
||||
# Sigmoid is applied for multi-label classification
|
||||
image_classifier = pipeline("image-classification", model=small_model)
|
||||
image_classifier.model.config.problem_type = "multi_label_classification"
|
||||
|
||||
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[{"label": "LABEL_1", "score": 0.5356}, {"label": "LABEL_0", "score": 0.4612}],
|
||||
)
|
||||
|
||||
outputs = image_classifier(
|
||||
[
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[{"label": "LABEL_1", "score": 0.5356}, {"label": "LABEL_0", "score": 0.4612}],
|
||||
[{"label": "LABEL_1", "score": 0.5356}, {"label": "LABEL_0", "score": 0.4612}],
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_function_to_apply(self):
|
||||
small_model = "hf-internal-testing/tiny-random-vit"
|
||||
|
||||
# Sigmoid is applied for multi-label classification
|
||||
image_classifier = pipeline("image-classification", model=small_model)
|
||||
|
||||
outputs = image_classifier(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
function_to_apply="sigmoid",
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[{"label": "LABEL_1", "score": 0.5356}, {"label": "LABEL_0", "score": 0.4612}],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user