From 6d4cabda2614d86357092585b416c4d08be73382 Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco <69953243+EduardoPach@users.noreply.github.com> Date: Fri, 26 Apr 2024 20:40:12 +0200 Subject: [PATCH] [SegGPT] Fix seggpt image processor (#29550) * Fixed SegGptImageProcessor to handle 2D and 3D prompt mask inputs * Added new test to check prompt mask equivalence * New proposal * Better proposal * Removed unnecessary method * Updated seggpt docs * Introduced do_convert_rgb * nits --- docs/source/en/model_doc/seggpt.md | 5 +- .../models/seggpt/image_processing_seggpt.py | 103 ++++++++---------- .../seggpt/test_image_processing_seggpt.py | 83 +++++++++++++- tests/models/seggpt/test_modeling_seggpt.py | 22 +++- 4 files changed, 148 insertions(+), 65 deletions(-) diff --git a/docs/source/en/model_doc/seggpt.md b/docs/source/en/model_doc/seggpt.md index f821fc14a0..5a68d38fc9 100644 --- a/docs/source/en/model_doc/seggpt.md +++ b/docs/source/en/model_doc/seggpt.md @@ -26,7 +26,8 @@ The abstract from the paper is the following: Tips: - One can use [`SegGptImageProcessor`] to prepare image input, prompt and mask to the model. -- It's highly advisable to pass `num_labels` (not considering background) during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case. +- One can either use segmentation maps or RGB images as prompt masks. If using the latter, make sure to set `do_convert_rgb=False` in the `preprocess` method. +- It's highly advisable to pass `num_labels` when using segmentation maps as prompt masks (not considering background) during preprocessing and postprocessing with [`SegGptImageProcessor`] for your use case. - When doing inference with [`SegGptForImageSegmentation`] if your `batch_size` is greater than 1 you can use feature ensemble across your images by passing `feature_ensemble=True` in the forward method. 
Here's how to use the model for one-shot semantic segmentation: @@ -53,7 +54,7 @@ mask_prompt = ds[29]["label"] inputs = image_processor( images=image_input, prompt_images=image_prompt, - prompt_masks=mask_prompt, + prompt_masks=mask_prompt, num_labels=num_labels, return_tensors="pt" ) diff --git a/src/transformers/models/seggpt/image_processing_seggpt.py b/src/transformers/models/seggpt/image_processing_seggpt.py index 80fb94cdc7..1e4a5e23d0 100644 --- a/src/transformers/models/seggpt/image_processing_seggpt.py +++ b/src/transformers/models/seggpt/image_processing_seggpt.py @@ -26,19 +26,21 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - get_channel_dimension_axis, infer_channel_dimension_format, is_scaled_image, make_list_of_images, to_numpy_array, valid_images, ) -from ...utils import TensorType, is_torch_available, logging, requires_backends +from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends if is_torch_available(): import torch +if is_vision_available(): + pass + logger = logging.get_logger(__name__) @@ -65,29 +67,10 @@ def build_palette(num_labels: int) -> List[Tuple[int, int]]: return color_list -def get_num_channels(image: np.ndarray, input_data_format: ChannelDimension) -> int: - if image.ndim == 2: - return 0 - - channel_idx = get_channel_dimension_axis(image, input_data_format) - return image.shape[channel_idx] - - def mask_to_rgb( - mask: np.ndarray, - palette: Optional[List[Tuple[int, int]]] = None, - input_data_format: Optional[ChannelDimension] = None, - data_format: Optional[ChannelDimension] = None, + mask: np.ndarray, palette: Optional[List[Tuple[int, int]]] = None, data_format: Optional[ChannelDimension] = None ) -> np.ndarray: - if input_data_format is None and mask.ndim > 2: - input_data_format = infer_channel_dimension_format(mask) - - data_format = data_format if data_format is not None else input_data_format - - num_channels = get_num_channels(mask, 
input_data_format) - - if num_channels == 3: - return to_channel_dimension_format(mask, data_format, input_data_format) if data_format is not None else mask + data_format = data_format if data_format is not None else ChannelDimension.FIRST if palette is not None: height, width = mask.shape @@ -109,9 +92,7 @@ def mask_to_rgb( else: rgb_mask = np.repeat(mask[None, ...], 3, axis=0) - return ( - to_channel_dimension_format(rgb_mask, data_format, input_data_format) if data_format is not None else rgb_mask - ) + return to_channel_dimension_format(rgb_mask, data_format) class SegGptImageProcessor(BaseImageProcessor): @@ -143,6 +124,9 @@ class SegGptImageProcessor(BaseImageProcessor): image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the prompt mask to RGB format. Can be overridden by the `do_convert_rgb` parameter in the + `preprocess` method. """ model_input_names = ["pixel_values"] @@ -157,6 +141,7 @@ class SegGptImageProcessor(BaseImageProcessor): do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -170,6 +155,7 @@ class SegGptImageProcessor(BaseImageProcessor): self.rescale_factor = rescale_factor self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_convert_rgb = do_convert_rgb def get_palette(self, num_labels: int) -> List[Tuple[int, int]]: """Build a palette to map the prompt mask from a single channel to a 3 channel RGB. 
@@ -188,13 +174,12 @@ class SegGptImageProcessor(BaseImageProcessor): image: np.ndarray, palette: Optional[List[Tuple[int, int]]] = None, data_format: Optional[Union[str, ChannelDimension]] = None, - input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: - """Convert a mask to RGB format. + """Converts a segmentation map to RGB format. Args: image (`np.ndarray`): - Mask to convert to RGB format. If the mask is already in RGB format, it will be passed through. + Segmentation map with dimensions (height, width) where pixel values represent the class index. palette (`List[Tuple[int, int]]`, *optional*, defaults to `None`): Palette to use to convert the mask to RGB format. If unset, the mask is duplicated across the channel dimension. @@ -203,21 +188,11 @@ class SegGptImageProcessor(BaseImageProcessor): image is used. Can be one of: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - input_data_format (`ChannelDimension` or `str`, *optional*): - The channel dimension format for the input image. If unset, the channel dimension format is inferred - from the input image. Can be one of: - - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. Returns: `np.ndarray`: The mask in RGB format. 
""" - return mask_to_rgb( - image, - palette=palette, - data_format=data_format, - input_data_format=input_data_format, - ) + return mask_to_rgb(image, palette=palette, data_format=data_format) # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC def resize( @@ -271,7 +246,6 @@ class SegGptImageProcessor(BaseImageProcessor): def _preprocess_step( self, images: ImageInput, - is_mask: bool = False, do_resize: Optional[bool] = None, size: Dict[str, int] = None, resample: PILImageResampling = None, @@ -282,6 +256,7 @@ class SegGptImageProcessor(BaseImageProcessor): image_std: Optional[Union[float, List[float]]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, num_labels: Optional[int] = None, **kwargs, ): @@ -292,9 +267,6 @@ class SegGptImageProcessor(BaseImageProcessor): images (`ImageInput`): Image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. - is_mask (`bool`, *optional*, defaults to `False`): - Whether the image is a mask. If True, the image is converted to RGB using the palette if - `self.num_labels` is specified otherwise RGB is achieved by duplicating the channel. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): @@ -331,6 +303,10 @@ class SegGptImageProcessor(BaseImageProcessor): - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built + to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated + across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format. num_labels: (`int`, *optional*): Number of classes in the segmentation task (excluding the background). If specified, a palette will be built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx @@ -340,6 +316,7 @@ class SegGptImageProcessor(BaseImageProcessor): do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb resample = resample if resample is not None else self.resample rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor image_mean = image_mean if image_mean is not None else self.image_mean @@ -348,7 +325,8 @@ class SegGptImageProcessor(BaseImageProcessor): size = size if size is not None else self.size size_dict = get_size_dict(size) - images = make_list_of_images(images) + # If segmentation map is passed we expect 2D images + images = make_list_of_images(images, expected_ndims=2 if do_convert_rgb else 3) if not valid_images(images): raise ValueError( @@ -374,11 +352,11 @@ class SegGptImageProcessor(BaseImageProcessor): " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." ) - if input_data_format is None and not is_mask: + if input_data_format is None and not do_convert_rgb: # We assume that all images have the same channel dimension format. 
input_data_format = infer_channel_dimension_format(images[0]) - if is_mask: + if do_convert_rgb: palette = self.get_palette(num_labels) if num_labels is not None else None # Since this is the input for the next transformations its format should be the same as the input_data_format images = [ @@ -423,6 +401,7 @@ class SegGptImageProcessor(BaseImageProcessor): do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, num_labels: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, @@ -440,9 +419,12 @@ class SegGptImageProcessor(BaseImageProcessor): Prompt image to _preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. prompt_masks (`ImageInput`): - Prompt mask from prompt image to _preprocess. Expects a single or batch of masks. If the mask masks are - a single channel then it will be converted to RGB using the palette if `self.num_labels` is specified - or by just repeating the channel if not. If the mask is already in RGB format, it will be passed through. + Prompt mask from prompt image to _preprocess that specifies the prompt_masks value in the preprocessed output. + Can either be in the format of segmentation maps (no channels) or RGB images. If in the format of + RGB images, `do_convert_rgb` should be set to `False`. If in the format of segmentation maps, + specifying `num_labels` is recommended to build a palette to map the prompt mask from a single channel to + a 3 channel RGB. If `num_labels` is not specified, the prompt mask will be duplicated across the channel + dimension. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. 
size (`Dict[str, int]`, *optional*, defaults to `self.size`): @@ -461,6 +443,16 @@ class SegGptImageProcessor(BaseImageProcessor): Image mean to use if `do_normalize` is set to `True`. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): Image standard deviation to use if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the prompt mask to RGB format. If `num_labels` is specified, a palette will be built + to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated + across the channel dimension. Must be set to `False` if the prompt mask is already in RGB format. + num_labels: (`int`, *optional*): + Number of classes in the segmentation task (excluding the background). If specified, a palette will be + built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map + with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed + through as is if it is already in RGB format (if `do_convert_rgb` is false) or being duplicated + across the channel dimension. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -479,11 +471,6 @@ class SegGptImageProcessor(BaseImageProcessor): - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - num_labels: (`int`, *optional*): - Number of classes in the segmentation task (excluding the background). If specified, a palette will be - built, assuming that class_idx 0 is the background, to map the prompt mask from a single class_idx - channel to a 3 channel RGB. 
Not specifying this will result in the prompt mask either being passed - through as is if it is already in RGB format or being duplicated across the channel dimension. """ if all(v is None for v in [images, prompt_images, prompt_masks]): raise ValueError("At least one of images, prompt_images, prompt_masks must be specified.") @@ -502,6 +489,7 @@ class SegGptImageProcessor(BaseImageProcessor): do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, + do_convert_rgb=False, data_format=data_format, input_data_format=input_data_format, **kwargs, @@ -521,6 +509,7 @@ class SegGptImageProcessor(BaseImageProcessor): do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, + do_convert_rgb=False, data_format=data_format, input_data_format=input_data_format, **kwargs, @@ -531,7 +520,6 @@ class SegGptImageProcessor(BaseImageProcessor): if prompt_masks is not None: prompt_masks = self._preprocess_step( prompt_masks, - is_mask=True, do_resize=do_resize, size=size, resample=PILImageResampling.NEAREST, @@ -540,9 +528,10 @@ class SegGptImageProcessor(BaseImageProcessor): do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, + do_convert_rgb=do_convert_rgb, + num_labels=num_labels, data_format=data_format, input_data_format=input_data_format, - num_labels=num_labels, **kwargs, ) diff --git a/tests/models/seggpt/test_image_processing_seggpt.py b/tests/models/seggpt/test_image_processing_seggpt.py index 46694d6636..04cefb70d0 100644 --- a/tests/models/seggpt/test_image_processing_seggpt.py +++ b/tests/models/seggpt/test_image_processing_seggpt.py @@ -30,6 +30,8 @@ if is_torch_available(): from transformers.models.seggpt.modeling_seggpt import SegGptImageSegmentationOutput if is_vision_available(): + from PIL import Image + from transformers import SegGptImageProcessor @@ -147,7 +149,7 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): mask_rgb = mask_binary.convert("RGB") inputs_binary = 
image_processor(images=None, prompt_masks=mask_binary, return_tensors="pt") - inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt") + inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt", do_convert_rgb=False) self.assertTrue((inputs_binary["prompt_masks"] == inputs_rgb["prompt_masks"]).all().item()) @@ -196,7 +198,11 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large") inputs = image_processor( - images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt" + images=input_image, + prompt_images=prompt_image, + prompt_masks=prompt_mask, + return_tensors="pt", + do_convert_rgb=False, ) # Verify pixel values @@ -229,3 +235,76 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4) ) self.assertTrue(torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4)) + + def test_prompt_mask_equivalence(self): + image_processor = self.image_processing_class(**self.image_processor_dict) + image_size = self.image_processor_tester.image_size + + # Single Mask Examples + expected_single_shape = [1, 3, image_size, image_size] + + # Single Semantic Map (2D) + image_np_2d = np.ones((image_size, image_size)) + image_pt_2d = torch.ones((image_size, image_size)) + image_pil_2d = Image.fromarray(image_np_2d) + + inputs_np_2d = image_processor(images=None, prompt_masks=image_np_2d, return_tensors="pt") + inputs_pt_2d = image_processor(images=None, prompt_masks=image_pt_2d, return_tensors="pt") + inputs_pil_2d = image_processor(images=None, prompt_masks=image_pil_2d, return_tensors="pt") + + self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pt_2d["prompt_masks"]).all().item()) + self.assertTrue((inputs_np_2d["prompt_masks"] == 
inputs_pil_2d["prompt_masks"]).all().item()) + self.assertEqual(list(inputs_np_2d["prompt_masks"].shape), expected_single_shape) + + # Single RGB Images (3D) + image_np_3d = np.ones((3, image_size, image_size)) + image_pt_3d = torch.ones((3, image_size, image_size)) + image_pil_3d = Image.fromarray(image_np_3d.transpose(1, 2, 0).astype(np.uint8)) + + inputs_np_3d = image_processor( + images=None, prompt_masks=image_np_3d, return_tensors="pt", do_convert_rgb=False + ) + inputs_pt_3d = image_processor( + images=None, prompt_masks=image_pt_3d, return_tensors="pt", do_convert_rgb=False + ) + inputs_pil_3d = image_processor( + images=None, prompt_masks=image_pil_3d, return_tensors="pt", do_convert_rgb=False + ) + + self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pt_3d["prompt_masks"]).all().item()) + self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pil_3d["prompt_masks"]).all().item()) + self.assertEqual(list(inputs_np_3d["prompt_masks"].shape), expected_single_shape) + + # Batched Examples + expected_batched_shape = [2, 3, image_size, image_size] + + # Batched Semantic Maps (3D) + image_np_2d_batched = np.ones((2, image_size, image_size)) + image_pt_2d_batched = torch.ones((2, image_size, image_size)) + + inputs_np_2d_batched = image_processor(images=None, prompt_masks=image_np_2d_batched, return_tensors="pt") + inputs_pt_2d_batched = image_processor(images=None, prompt_masks=image_pt_2d_batched, return_tensors="pt") + + self.assertTrue((inputs_np_2d_batched["prompt_masks"] == inputs_pt_2d_batched["prompt_masks"]).all().item()) + self.assertEqual(list(inputs_np_2d_batched["prompt_masks"].shape), expected_batched_shape) + + # Batched RGB images + image_np_4d = np.ones((2, 3, image_size, image_size)) + image_pt_4d = torch.ones((2, 3, image_size, image_size)) + + inputs_np_4d = image_processor( + images=None, prompt_masks=image_np_4d, return_tensors="pt", do_convert_rgb=False + ) + inputs_pt_4d = image_processor( + images=None, prompt_masks=image_pt_4d, 
return_tensors="pt", do_convert_rgb=False + ) + + self.assertTrue((inputs_np_4d["prompt_masks"] == inputs_pt_4d["prompt_masks"]).all().item()) + self.assertEqual(list(inputs_np_4d["prompt_masks"].shape), expected_batched_shape) + + # Comparing Single and Batched Examples + self.assertTrue((inputs_np_2d["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item()) + self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_2d["prompt_masks"][0]).all().item()) + self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item()) + self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_4d["prompt_masks"][0]).all().item()) + self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item()) diff --git a/tests/models/seggpt/test_modeling_seggpt.py b/tests/models/seggpt/test_modeling_seggpt.py index d43d430453..efa0231c1e 100644 --- a/tests/models/seggpt/test_modeling_seggpt.py +++ b/tests/models/seggpt/test_modeling_seggpt.py @@ -363,7 +363,11 @@ class SegGptModelIntegrationTest(unittest.TestCase): prompt_mask = masks[0] inputs = image_processor( - images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt" + images=input_image, + prompt_images=prompt_image, + prompt_masks=prompt_mask, + return_tensors="pt", + do_convert_rgb=False, ) inputs = inputs.to(torch_device) @@ -404,7 +408,11 @@ class SegGptModelIntegrationTest(unittest.TestCase): prompt_masks = [masks[0], masks[2]] inputs = image_processor( - images=input_images, prompt_images=prompt_images, prompt_masks=prompt_masks, return_tensors="pt" + images=input_images, + prompt_images=prompt_images, + prompt_masks=prompt_masks, + return_tensors="pt", + do_convert_rgb=False, ) inputs = {k: v.to(torch_device) for k, v in inputs.items()} @@ -437,10 +445,16 @@ class SegGptModelIntegrationTest(unittest.TestCase): prompt_mask = masks[0] inputs = image_processor( - 
images=input_image, prompt_masks=prompt_mask, prompt_images=prompt_image, return_tensors="pt" + images=input_image, + prompt_masks=prompt_mask, + prompt_images=prompt_image, + return_tensors="pt", + do_convert_rgb=False, ).to(torch_device) - labels = image_processor(images=None, prompt_masks=label, return_tensors="pt")["prompt_masks"].to(torch_device) + labels = image_processor(images=None, prompt_masks=label, return_tensors="pt", do_convert_rgb=False)[ + "prompt_masks" + ].to(torch_device) bool_masked_pos = prepare_bool_masked_pos(model.config).to(torch_device)