From 83a2e694f1208e778488875a567bb69c90028536 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Fri, 14 Oct 2022 14:30:45 +0100 Subject: [PATCH] Cast masks to np.unit8 before converting to PIL.Image.Image (#19616) * Cast masks to np.unit8 before converting to PIL.Image.Image * Update tests * Fixup --- src/transformers/pipelines/image_segmentation.py | 2 +- tests/pipelines/test_pipelines_image_segmentation.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index 690247f6e4..a085360e34 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -172,7 +172,7 @@ class ImageSegmentationPipeline(Pipeline): for label in labels: mask = (segmentation == label) * 255 - mask = Image.fromarray(mask, mode="L") + mask = Image.fromarray(mask.astype(np.uint8), mode="L") label = self.model.config.id2label[label] annotation.append({"score": None, "label": label, "mask": mask}) else: diff --git a/tests/pipelines/test_pipelines_image_segmentation.py b/tests/pipelines/test_pipelines_image_segmentation.py index 3d7d067afa..42023aeebe 100644 --- a/tests/pipelines/test_pipelines_image_segmentation.py +++ b/tests/pipelines/test_pipelines_image_segmentation.py @@ -226,15 +226,11 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa self.assertEqual( nested_simplify(outputs, decimals=4), [ - { - "score": None, - "label": "LABEL_0", - "mask": "775518a7ed09eea888752176c6ba8f38", - }, + {"score": None, "label": "LABEL_0", "mask": "42d09072282a32da2ac77375a4c1280f"}, { "score": None, "label": "LABEL_1", - "mask": "a12da23a46848128af68c63aa8ba7a02", + "mask": "46b8cc3976732873b219f77a1213c1a5", }, ], )