fixed Mask2Former image processor segmentation maps handling (#33364)
* fixed mask2former image processor segmentation maps handling * introduced review suggestions * introduced review suggestions
This commit is contained in:
@@ -935,7 +935,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
|||||||
if segmentation_maps is not None:
|
if segmentation_maps is not None:
|
||||||
mask_labels = []
|
mask_labels = []
|
||||||
class_labels = []
|
class_labels = []
|
||||||
pad_size = get_max_height_width(pixel_values_list)
|
pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
|
||||||
# Convert to list of binary masks and labels
|
# Convert to list of binary masks and labels
|
||||||
for idx, segmentation_map in enumerate(segmentation_maps):
|
for idx, segmentation_map in enumerate(segmentation_maps):
|
||||||
segmentation_map = to_numpy_array(segmentation_map)
|
segmentation_map = to_numpy_array(segmentation_map)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import numpy as np
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from transformers.image_utils import ChannelDimension
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
@@ -180,31 +181,44 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
|||||||
self.assertEqual(image_processor.size_divisor, 8)
|
self.assertEqual(image_processor.size_divisor, 8)
|
||||||
|
|
||||||
def comm_get_image_processing_inputs(
|
def comm_get_image_processing_inputs(
|
||||||
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
|
self,
|
||||||
|
image_processor_tester,
|
||||||
|
with_segmentation_maps=False,
|
||||||
|
is_instance_map=False,
|
||||||
|
segmentation_type="np",
|
||||||
|
numpify=False,
|
||||||
|
input_data_format=None,
|
||||||
):
|
):
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = self.image_processing_class(**image_processor_tester.prepare_image_processor_dict())
|
||||||
# prepare image and target
|
# prepare image and target
|
||||||
num_labels = self.image_processor_tester.num_labels
|
num_labels = image_processor_tester.num_labels
|
||||||
annotations = None
|
annotations = None
|
||||||
instance_id_to_semantic_id = None
|
instance_id_to_semantic_id = None
|
||||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
|
image_inputs = image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=numpify)
|
||||||
if with_segmentation_maps:
|
if with_segmentation_maps:
|
||||||
high = num_labels
|
high = num_labels
|
||||||
if is_instance_map:
|
if is_instance_map:
|
||||||
labels_expanded = list(range(num_labels)) * 2
|
labels_expanded = list(range(num_labels)) * 2
|
||||||
instance_id_to_semantic_id = dict(enumerate(labels_expanded))
|
instance_id_to_semantic_id = dict(enumerate(labels_expanded))
|
||||||
annotations = [
|
annotations = [
|
||||||
np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs
|
np.random.randint(0, high * 2, img.shape[:2] if numpify else (img.size[1], img.size[0])).astype(
|
||||||
|
np.uint8
|
||||||
|
)
|
||||||
|
for img in image_inputs
|
||||||
]
|
]
|
||||||
if segmentation_type == "pil":
|
if segmentation_type == "pil":
|
||||||
annotations = [Image.fromarray(annotation) for annotation in annotations]
|
annotations = [Image.fromarray(annotation) for annotation in annotations]
|
||||||
|
|
||||||
|
if input_data_format is ChannelDimension.FIRST and numpify:
|
||||||
|
image_inputs = [np.moveaxis(img, -1, 0) for img in image_inputs]
|
||||||
|
|
||||||
inputs = image_processing(
|
inputs = image_processing(
|
||||||
image_inputs,
|
image_inputs,
|
||||||
annotations,
|
annotations,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
instance_id_to_semantic_id=instance_id_to_semantic_id,
|
||||||
pad_and_return_pixel_mask=True,
|
pad_and_return_pixel_mask=True,
|
||||||
|
input_data_format=input_data_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
@@ -223,9 +237,29 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
|||||||
self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0)
|
self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0)
|
||||||
|
|
||||||
def test_call_with_segmentation_maps(self):
|
def test_call_with_segmentation_maps(self):
|
||||||
def common(is_instance_map=False, segmentation_type=None):
|
def common(
|
||||||
|
is_instance_map=False,
|
||||||
|
segmentation_type=None,
|
||||||
|
numpify=False,
|
||||||
|
num_channels=3,
|
||||||
|
input_data_format=None,
|
||||||
|
do_resize=True,
|
||||||
|
):
|
||||||
|
image_processor_tester = Mask2FormerImageProcessingTester(
|
||||||
|
self,
|
||||||
|
num_channels=num_channels,
|
||||||
|
do_resize=do_resize,
|
||||||
|
image_mean=[0.5] * num_channels,
|
||||||
|
image_std=[0.5] * num_channels,
|
||||||
|
)
|
||||||
|
|
||||||
inputs = self.comm_get_image_processing_inputs(
|
inputs = self.comm_get_image_processing_inputs(
|
||||||
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
|
image_processor_tester=image_processor_tester,
|
||||||
|
with_segmentation_maps=True,
|
||||||
|
is_instance_map=is_instance_map,
|
||||||
|
segmentation_type=segmentation_type,
|
||||||
|
numpify=numpify,
|
||||||
|
input_data_format=input_data_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_labels = inputs["mask_labels"]
|
mask_labels = inputs["mask_labels"]
|
||||||
@@ -243,6 +277,18 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
|||||||
common(is_instance_map=False, segmentation_type="pil")
|
common(is_instance_map=False, segmentation_type="pil")
|
||||||
common(is_instance_map=True, segmentation_type="pil")
|
common(is_instance_map=True, segmentation_type="pil")
|
||||||
|
|
||||||
|
common(num_channels=1, numpify=True)
|
||||||
|
common(num_channels=1, numpify=True, input_data_format=ChannelDimension.FIRST)
|
||||||
|
common(num_channels=2, numpify=True, input_data_format=ChannelDimension.LAST)
|
||||||
|
common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST, do_resize=False)
|
||||||
|
common(num_channels=5, numpify=True, input_data_format=ChannelDimension.FIRST, do_resize=False)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, expected_regex="Unable to infer channel dimension format"):
|
||||||
|
common(num_channels=5, numpify=True, do_resize=False)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(TypeError, expected_regex=r"Cannot handle this data type: .*"):
|
||||||
|
common(num_channels=5, numpify=True, input_data_format=ChannelDimension.LAST)
|
||||||
|
|
||||||
def test_integration_instance_segmentation(self):
|
def test_integration_instance_segmentation(self):
|
||||||
# load 2 images and corresponding annotations from the hub
|
# load 2 images and corresponding annotations from the hub
|
||||||
repo_id = "nielsr/image-segmentation-toy-data"
|
repo_id = "nielsr/image-segmentation-toy-data"
|
||||||
|
|||||||
Reference in New Issue
Block a user