Improve and fix ImageSegmentationPipeline (#19367)
- Fixes the image segmentation pipeline test failures caused by changes to the postprocessing methods of supported models
- Updates the ImageSegmentationPipeline tests
- Improves docs and adds a 'task' argument to optionally perform semantic, instance or panoptic segmentation
This commit is contained in:
@@ -74,9 +74,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
}
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||
# Fix me Alara
|
||||
if model.__class__.__name__ in ["DetrForSegmentation", "MaskFormerForInstanceSegmentation"]:
|
||||
return None, None
|
||||
image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)
|
||||
return image_segmenter, [
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
@@ -150,7 +147,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
pass
|
||||
|
||||
@require_torch
|
||||
@unittest.skip("Fix me Alara!")
|
||||
@unittest.skip("No weights found for hf-internal-testing/tiny-detr-mobilenetsv3-panoptic")
|
||||
def test_small_model_pt(self):
|
||||
model_id = "hf-internal-testing/tiny-detr-mobilenetsv3-panoptic"
|
||||
|
||||
@@ -158,9 +155,15 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)
|
||||
|
||||
outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.0)
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
task="panoptic",
|
||||
threshold=0.0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
# shortening by hashing
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
@@ -235,12 +238,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
{
|
||||
"score": None,
|
||||
"label": "LABEL_0",
|
||||
"mask": "6225140faf502d272af076222776d7e4",
|
||||
"mask": "775518a7ed09eea888752176c6ba8f38",
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "LABEL_1",
|
||||
"mask": "8297c9f8eb43ddd3f32a6dae21e015a1",
|
||||
"mask": "a12da23a46848128af68c63aa8ba7a02",
|
||||
},
|
||||
],
|
||||
)
|
||||
@@ -249,22 +252,28 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
@slow
|
||||
def test_integration_torch_image_segmentation(self):
|
||||
model_id = "facebook/detr-resnet-50-panoptic"
|
||||
|
||||
image_segmenter = pipeline("image-segmentation", model=model_id)
|
||||
|
||||
outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
task="panoptic",
|
||||
threshold=0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9094, "label": "blanket", "mask": "6500201749480f87154fd967783b2b97"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "f3a7f80220788acc0245ebc084df6afc"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "7703408f54da1d0ebda47841da875e48"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "bd726918f10fed3efaef0091e11f923b"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "226d6dcb98bebc3fbc208abdc0c83196"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "fa5d8d5c329546ba5339f3095641ef56"},
|
||||
{"score": 0.9094, "label": "blanket", "mask": "dcff19a97abd8bd555e21186ae7c066a"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -273,8 +282,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
],
|
||||
task="panoptic",
|
||||
threshold=0.0,
|
||||
overlap_mask_area_threshold=0.0,
|
||||
)
|
||||
|
||||
# Shortening by hashing
|
||||
for output in outputs:
|
||||
for o in output:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
@@ -283,20 +296,20 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[
|
||||
{"score": 0.9094, "label": "blanket", "mask": "6500201749480f87154fd967783b2b97"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "f3a7f80220788acc0245ebc084df6afc"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "7703408f54da1d0ebda47841da875e48"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "bd726918f10fed3efaef0091e11f923b"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "226d6dcb98bebc3fbc208abdc0c83196"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "fa5d8d5c329546ba5339f3095641ef56"},
|
||||
{"score": 0.9094, "label": "blanket", "mask": "dcff19a97abd8bd555e21186ae7c066a"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
],
|
||||
[
|
||||
{"score": 0.9094, "label": "blanket", "mask": "6500201749480f87154fd967783b2b97"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "f3a7f80220788acc0245ebc084df6afc"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "7703408f54da1d0ebda47841da875e48"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "bd726918f10fed3efaef0091e11f923b"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "226d6dcb98bebc3fbc208abdc0c83196"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "fa5d8d5c329546ba5339f3095641ef56"},
|
||||
{"score": 0.9094, "label": "blanket", "mask": "dcff19a97abd8bd555e21186ae7c066a"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
],
|
||||
],
|
||||
)
|
||||
@@ -304,12 +317,27 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
@require_torch
|
||||
@slow
|
||||
def test_threshold(self):
|
||||
threshold = 0.999
|
||||
model_id = "facebook/detr-resnet-50-panoptic"
|
||||
|
||||
image_segmenter = pipeline("image-segmentation", model=model_id)
|
||||
|
||||
outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=threshold)
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg", task="panoptic", threshold=0.999
|
||||
)
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9995, "label": "remote", "mask": "d02404f5789f075e3b3174adbc3fd5b8"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "eaa115b40c96d3a6f4fe498963a7e470"},
|
||||
],
|
||||
)
|
||||
|
||||
outputs = image_segmenter(
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg", task="panoptic", threshold=0.5
|
||||
)
|
||||
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
@@ -317,8 +345,11 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9995, "label": "remote", "mask": "bd726918f10fed3efaef0091e11f923b"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "fa5d8d5c329546ba5339f3095641ef56"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -335,20 +366,21 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
|
||||
image = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
file = image[0]["file"]
|
||||
outputs = image_segmenter(file, threshold=threshold)
|
||||
outputs = image_segmenter(file, task="panoptic", threshold=threshold)
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"mask": "20d1b9480d1dc1501dbdcfdff483e370", "label": "wall", "score": None},
|
||||
{"mask": "0f902fbc66a0ff711ea455b0e4943adf", "label": "house", "score": None},
|
||||
{"mask": "4537bdc07d47d84b3f8634b7ada37bd4", "label": "grass", "score": None},
|
||||
{"mask": "b7ac77dfae44a904b479a0926a2acaf7", "label": "tree", "score": None},
|
||||
{"mask": "e9bedd56bd40650fb263ce03eb621079", "label": "plant", "score": None},
|
||||
{"mask": "37a609f8c9c1b8db91fbff269f428b20", "label": "road, route", "score": None},
|
||||
{"mask": "0d8cdfd63bae8bf6e4344d460a2fa711", "label": "sky", "score": None},
|
||||
{"score": 0.9974, "label": "wall", "mask": "a547b7c062917f4f3e36501827ad3cd6"},
|
||||
{"score": 0.949, "label": "house", "mask": "0da9b7b38feac47bd2528a63e5ea7b19"},
|
||||
{"score": 0.9995, "label": "grass", "mask": "1d07ea0a263dcf38ca8ae1a15fdceda1"},
|
||||
{"score": 0.9976, "label": "tree", "mask": "6cdc97c7daf1dc596fa181f461ddd2ba"},
|
||||
{"score": 0.8239, "label": "plant", "mask": "1ab4ce378f6ceff57d428055cfbd742f"},
|
||||
{"score": 0.9942, "label": "road, route", "mask": "39c5d17be53b2d1b0f46aad8ebb15813"},
|
||||
{"score": 1.0, "label": "sky", "mask": "a3756324a692981510c39b1a59510a36"},
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user