Cast masks to np.unit8 before converting to PIL.Image.Image (#19616)
* Cast masks to np.unit8 before converting to PIL.Image.Image * Update tests * Fixup
This commit is contained in:
@@ -172,7 +172,7 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
|
||||
for label in labels:
|
||||
mask = (segmentation == label) * 255
|
||||
mask = Image.fromarray(mask, mode="L")
|
||||
mask = Image.fromarray(mask.astype(np.uint8), mode="L")
|
||||
label = self.model.config.id2label[label]
|
||||
annotation.append({"score": None, "label": label, "mask": mask})
|
||||
else:
|
||||
|
||||
@@ -226,15 +226,11 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{
|
||||
"score": None,
|
||||
"label": "LABEL_0",
|
||||
"mask": "775518a7ed09eea888752176c6ba8f38",
|
||||
},
|
||||
{"score": None, "label": "LABEL_0", "mask": "42d09072282a32da2ac77375a4c1280f"},
|
||||
{
|
||||
"score": None,
|
||||
"label": "LABEL_1",
|
||||
"mask": "a12da23a46848128af68c63aa8ba7a02",
|
||||
"mask": "46b8cc3976732873b219f77a1213c1a5",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user