Cast masks to np.unit8 before converting to PIL.Image.Image (#19616)

* Cast masks to np.unit8 before converting to PIL.Image.Image

* Update tests

* Fixup
This commit is contained in:
amyeroberts
2022-10-14 14:30:45 +01:00
committed by GitHub
parent 909f07092a
commit 83a2e694f1
2 changed files with 3 additions and 7 deletions

View File

@@ -172,7 +172,7 @@ class ImageSegmentationPipeline(Pipeline):
for label in labels: for label in labels:
mask = (segmentation == label) * 255 mask = (segmentation == label) * 255
mask = Image.fromarray(mask, mode="L") mask = Image.fromarray(mask.astype(np.uint8), mode="L")
label = self.model.config.id2label[label] label = self.model.config.id2label[label]
annotation.append({"score": None, "label": label, "mask": mask}) annotation.append({"score": None, "label": label, "mask": mask})
else: else:

View File

@@ -226,15 +226,11 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{ {"score": None, "label": "LABEL_0", "mask": "42d09072282a32da2ac77375a4c1280f"},
"score": None,
"label": "LABEL_0",
"mask": "775518a7ed09eea888752176c6ba8f38",
},
{ {
"score": None, "score": None,
"label": "LABEL_1", "label": "LABEL_1",
"mask": "a12da23a46848128af68c63aa8ba7a02", "mask": "46b8cc3976732873b219f77a1213c1a5",
}, },
], ],
) )