Improving image-segmentation pipeline tests. (#19710)
This PR (https://github.com/huggingface/transformers/pull/19367) introduced a few breaking changes: - Removed an argument `mask_threshold`. - Broke the default behavior (instance vs panoptic in the function call) https://github.com/huggingface/transformers/pull/19367/files#diff-60f846b86fb6a21d4caf60f5b3d593a04accb8f248de3029cccae2ff898c5bc3R119-R120 - Broke the actual masks: https://github.com/huggingface/transformers/pull/1961 This PR is the start of a handful that will aim at bringing back the old behavior(s). - tests should not have to specify `task` by default, unless we want to modify the behavior and have a lower form of segmentation running) - `test_small_model_pt` should be working. This specific PR starts with adding more information to the masks hash because missing the actual mask was actual easy to miss (the hashes do change, but it was easy to miss that one code path wasn't properly updated). So we go from a simple `hash` to ``` {"hash": #smaller hash, "shape": (h, w), "white_pixels": n} ``` The `shape` should help make sure the interpolation of the mask works correctly, the `white_pixels` hopefully helps detect big regressions in their amount when the hash gets modified.
This commit is contained in:
@@ -14,8 +14,10 @@
|
||||
|
||||
import hashlib
|
||||
import unittest
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import (
|
||||
@@ -48,7 +50,14 @@ else:
|
||||
|
||||
def hashimage(image: Image) -> str:
|
||||
m = hashlib.md5(image.tobytes())
|
||||
return m.hexdigest()
|
||||
return m.hexdigest()[:10]
|
||||
|
||||
|
||||
def mask_to_test_readable(mask: Image) -> Dict:
|
||||
npimg = np.array(mask)
|
||||
white_pixels = (npimg == 255).sum()
|
||||
shape = npimg.shape
|
||||
return {"hash": hashimage(mask), "white_pixels": white_pixels, "shape": shape}
|
||||
|
||||
|
||||
@require_vision
|
||||
@@ -155,7 +164,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
@@ -163,12 +172,12 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": "34eecd16bbfb0f476083ef947d81bf66",
|
||||
"mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
},
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": "34eecd16bbfb0f476083ef947d81bf66",
|
||||
"mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
},
|
||||
],
|
||||
)
|
||||
@@ -182,7 +191,7 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
)
|
||||
for output in outputs:
|
||||
for o in output:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
@@ -191,24 +200,24 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": "34eecd16bbfb0f476083ef947d81bf66",
|
||||
"mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
},
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": "34eecd16bbfb0f476083ef947d81bf66",
|
||||
"mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": "34eecd16bbfb0f476083ef947d81bf66",
|
||||
"mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
},
|
||||
{
|
||||
"score": 0.004,
|
||||
"label": "LABEL_215",
|
||||
"mask": "34eecd16bbfb0f476083ef947d81bf66",
|
||||
"mask": {"hash": "34eecd16bb", "shape": (480, 640), "white_pixels": 0},
|
||||
},
|
||||
],
|
||||
],
|
||||
@@ -221,16 +230,20 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
outputs = image_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg")
|
||||
for o in outputs:
|
||||
# shortening by hashing
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": None, "label": "LABEL_0", "mask": "42d09072282a32da2ac77375a4c1280f"},
|
||||
{
|
||||
"score": None,
|
||||
"label": "LABEL_0",
|
||||
"mask": {"hash": "42d0907228", "shape": (480, 640), "white_pixels": 10714},
|
||||
},
|
||||
{
|
||||
"score": None,
|
||||
"label": "LABEL_1",
|
||||
"mask": "46b8cc3976732873b219f77a1213c1a5",
|
||||
"mask": {"hash": "46b8cc3976", "shape": (480, 640), "white_pixels": 296486},
|
||||
},
|
||||
],
|
||||
)
|
||||
@@ -250,17 +263,41 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9094, "label": "blanket", "mask": "dcff19a97abd8bd555e21186ae7c066a"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
{
|
||||
"score": 0.9094,
|
||||
"label": "blanket",
|
||||
"mask": {"hash": "dcff19a97a", "shape": (480, 640), "white_pixels": 16617},
|
||||
},
|
||||
{
|
||||
"score": 0.9941,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "9c0af87bd0", "shape": (480, 640), "white_pixels": 59185},
|
||||
},
|
||||
{
|
||||
"score": 0.9987,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "c7870600d6", "shape": (480, 640), "white_pixels": 4182},
|
||||
},
|
||||
{
|
||||
"score": 0.9995,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "ef899a25fd", "shape": (480, 640), "white_pixels": 2275},
|
||||
},
|
||||
{
|
||||
"score": 0.9722,
|
||||
"label": "couch",
|
||||
"mask": {"hash": "37b8446ac5", "shape": (480, 640), "white_pixels": 172380},
|
||||
},
|
||||
{
|
||||
"score": 0.9994,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "6a09d3655e", "shape": (480, 640), "white_pixels": 52561},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -277,26 +314,74 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
# Shortening by hashing
|
||||
for output in outputs:
|
||||
for o in output:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[
|
||||
{"score": 0.9094, "label": "blanket", "mask": "dcff19a97abd8bd555e21186ae7c066a"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
{
|
||||
"score": 0.9094,
|
||||
"label": "blanket",
|
||||
"mask": {"hash": "dcff19a97a", "shape": (480, 640), "white_pixels": 16617},
|
||||
},
|
||||
{
|
||||
"score": 0.9941,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "9c0af87bd0", "shape": (480, 640), "white_pixels": 59185},
|
||||
},
|
||||
{
|
||||
"score": 0.9987,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "c7870600d6", "shape": (480, 640), "white_pixels": 4182},
|
||||
},
|
||||
{
|
||||
"score": 0.9995,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "ef899a25fd", "shape": (480, 640), "white_pixels": 2275},
|
||||
},
|
||||
{
|
||||
"score": 0.9722,
|
||||
"label": "couch",
|
||||
"mask": {"hash": "37b8446ac5", "shape": (480, 640), "white_pixels": 172380},
|
||||
},
|
||||
{
|
||||
"score": 0.9994,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "6a09d3655e", "shape": (480, 640), "white_pixels": 52561},
|
||||
},
|
||||
],
|
||||
[
|
||||
{"score": 0.9094, "label": "blanket", "mask": "dcff19a97abd8bd555e21186ae7c066a"},
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
{
|
||||
"score": 0.9094,
|
||||
"label": "blanket",
|
||||
"mask": {"hash": "dcff19a97a", "shape": (480, 640), "white_pixels": 16617},
|
||||
},
|
||||
{
|
||||
"score": 0.9941,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "9c0af87bd0", "shape": (480, 640), "white_pixels": 59185},
|
||||
},
|
||||
{
|
||||
"score": 0.9987,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "c7870600d6", "shape": (480, 640), "white_pixels": 4182},
|
||||
},
|
||||
{
|
||||
"score": 0.9995,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "ef899a25fd", "shape": (480, 640), "white_pixels": 2275},
|
||||
},
|
||||
{
|
||||
"score": 0.9722,
|
||||
"label": "couch",
|
||||
"mask": {"hash": "37b8446ac5", "shape": (480, 640), "white_pixels": 172380},
|
||||
},
|
||||
{
|
||||
"score": 0.9994,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "6a09d3655e", "shape": (480, 640), "white_pixels": 52561},
|
||||
},
|
||||
],
|
||||
],
|
||||
)
|
||||
@@ -312,13 +397,21 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
)
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9995, "label": "remote", "mask": "d02404f5789f075e3b3174adbc3fd5b8"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "eaa115b40c96d3a6f4fe498963a7e470"},
|
||||
{
|
||||
"score": 0.9995,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "d02404f578", "shape": (480, 640), "white_pixels": 2789},
|
||||
},
|
||||
{
|
||||
"score": 0.9994,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "eaa115b40c", "shape": (480, 640), "white_pixels": 304411},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -327,16 +420,36 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
)
|
||||
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9941, "label": "cat", "mask": "9c0af87bd00f9d3a4e0c8888e34e70e2"},
|
||||
{"score": 0.9987, "label": "remote", "mask": "c7870600d6c02a1f6d96470fc7220e8e"},
|
||||
{"score": 0.9995, "label": "remote", "mask": "ef899a25fd44ec056c653f0ca2954fdd"},
|
||||
{"score": 0.9722, "label": "couch", "mask": "37b8446ac578a17108aa2b7fccc33114"},
|
||||
{"score": 0.9994, "label": "cat", "mask": "6a09d3655efd8a388ab4511e4cbbb797"},
|
||||
{
|
||||
"score": 0.9941,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "9c0af87bd0", "shape": (480, 640), "white_pixels": 59185},
|
||||
},
|
||||
{
|
||||
"score": 0.9987,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "c7870600d6", "shape": (480, 640), "white_pixels": 4182},
|
||||
},
|
||||
{
|
||||
"score": 0.9995,
|
||||
"label": "remote",
|
||||
"mask": {"hash": "ef899a25fd", "shape": (480, 640), "white_pixels": 2275},
|
||||
},
|
||||
{
|
||||
"score": 0.9722,
|
||||
"label": "couch",
|
||||
"mask": {"hash": "37b8446ac5", "shape": (480, 640), "white_pixels": 172380},
|
||||
},
|
||||
{
|
||||
"score": 0.9994,
|
||||
"label": "cat",
|
||||
"mask": {"hash": "6a09d3655e", "shape": (480, 640), "white_pixels": 52561},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -357,17 +470,45 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
|
||||
# Shortening by hashing
|
||||
for o in outputs:
|
||||
o["mask"] = hashimage(o["mask"])
|
||||
o["mask"] = mask_to_test_readable(o["mask"])
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9974, "label": "wall", "mask": "a547b7c062917f4f3e36501827ad3cd6"},
|
||||
{"score": 0.949, "label": "house", "mask": "0da9b7b38feac47bd2528a63e5ea7b19"},
|
||||
{"score": 0.9995, "label": "grass", "mask": "1d07ea0a263dcf38ca8ae1a15fdceda1"},
|
||||
{"score": 0.9976, "label": "tree", "mask": "6cdc97c7daf1dc596fa181f461ddd2ba"},
|
||||
{"score": 0.8239, "label": "plant", "mask": "1ab4ce378f6ceff57d428055cfbd742f"},
|
||||
{"score": 0.9942, "label": "road, route", "mask": "39c5d17be53b2d1b0f46aad8ebb15813"},
|
||||
{"score": 1.0, "label": "sky", "mask": "a3756324a692981510c39b1a59510a36"},
|
||||
{
|
||||
"score": 0.9974,
|
||||
"label": "wall",
|
||||
"mask": {"hash": "a547b7c062", "shape": (512, 683), "white_pixels": 14252},
|
||||
},
|
||||
{
|
||||
"score": 0.949,
|
||||
"label": "house",
|
||||
"mask": {"hash": "0da9b7b38f", "shape": (512, 683), "white_pixels": 132177},
|
||||
},
|
||||
{
|
||||
"score": 0.9995,
|
||||
"label": "grass",
|
||||
"mask": {"hash": "1d07ea0a26", "shape": (512, 683), "white_pixels": 53444},
|
||||
},
|
||||
{
|
||||
"score": 0.9976,
|
||||
"label": "tree",
|
||||
"mask": {"hash": "6cdc97c7da", "shape": (512, 683), "white_pixels": 7944},
|
||||
},
|
||||
{
|
||||
"score": 0.8239,
|
||||
"label": "plant",
|
||||
"mask": {"hash": "1ab4ce378f", "shape": (512, 683), "white_pixels": 4136},
|
||||
},
|
||||
{
|
||||
"score": 0.9942,
|
||||
"label": "road, route",
|
||||
"mask": {"hash": "39c5d17be5", "shape": (512, 683), "white_pixels": 1941},
|
||||
},
|
||||
{
|
||||
"score": 1.0,
|
||||
"label": "sky",
|
||||
"mask": {"hash": "a3756324a6", "shape": (512, 683), "white_pixels": 135802},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user