Optionally preprocess segmentation maps for MobileViT (#28420)
* optionally preprocess segmentation maps for mobilevit * changed pretrained model name to that of segmentation model * removed voc-deeplabv3 from model archive list * added preprocess_image and preprocess_mask methods for processing images and segmentation masks respectively * added tests for segmentation masks based on segformer feature extractor * use crop_size instead of size * reverting to initial model
This commit is contained in:
@@ -19,12 +19,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
flip_channel_order,
|
||||
get_resize_output_image_size,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
@@ -178,9 +173,126 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
|
||||
|
||||
def __call__(self, images, segmentation_maps=None, **kwargs):
|
||||
"""
|
||||
Preprocesses a batch of images and optionally segmentation maps.
|
||||
|
||||
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
|
||||
passed in as positional arguments.
|
||||
"""
|
||||
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
image: ImageInput,
|
||||
do_resize: bool,
|
||||
do_rescale: bool,
|
||||
do_center_crop: bool,
|
||||
do_flip_channel_order: bool,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
crop_size: Optional[Dict[str, int]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_flip_channel_order:
|
||||
image = self.flip_channel_order(image, input_data_format=input_data_format)
|
||||
|
||||
return image
|
||||
|
||||
def _preprocess_image(
|
||||
self,
|
||||
image: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_center_crop: bool = None,
|
||||
crop_size: Dict[str, int] = None,
|
||||
do_flip_channel_order: bool = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
if is_scaled_image(image) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
|
||||
image = self._preprocess(
|
||||
image=image,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_center_crop=do_center_crop,
|
||||
crop_size=crop_size,
|
||||
do_flip_channel_order=do_flip_channel_order,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
|
||||
return image
|
||||
|
||||
def _preprocess_mask(
|
||||
self,
|
||||
segmentation_map: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
do_center_crop: bool = None,
|
||||
crop_size: Dict[str, int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single mask."""
|
||||
segmentation_map = to_numpy_array(segmentation_map)
|
||||
# Add channel dimension if missing - needed for certain transformations
|
||||
if segmentation_map.ndim == 2:
|
||||
added_channel_dim = True
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
input_data_format = ChannelDimension.FIRST
|
||||
else:
|
||||
added_channel_dim = False
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
||||
|
||||
segmentation_map = self._preprocess(
|
||||
image=segmentation_map,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=PILImageResampling.NEAREST,
|
||||
do_rescale=False,
|
||||
do_center_crop=do_center_crop,
|
||||
crop_size=crop_size,
|
||||
do_flip_channel_order=False,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
# Remove extra channel dimension if added for processing
|
||||
if added_channel_dim:
|
||||
segmentation_map = segmentation_map.squeeze(0)
|
||||
segmentation_map = segmentation_map.astype(np.int64)
|
||||
return segmentation_map
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
segmentation_maps: Optional[ImageInput] = None,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
@@ -201,6 +313,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
segmentation_maps (`ImageInput`, *optional*):
|
||||
Segmentation map to preprocess.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
@@ -251,6 +365,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
images = make_list_of_images(images)
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
@@ -258,6 +374,12 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if segmentation_maps is not None and not valid_images(segmentation_maps):
|
||||
raise ValueError(
|
||||
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if do_resize and size is None:
|
||||
raise ValueError("Size must be specified if do_resize is True.")
|
||||
|
||||
@@ -267,45 +389,40 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
if do_center_crop and crop_size is None:
|
||||
raise ValueError("Crop size must be specified if do_center_crop is True.")
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
images = [
|
||||
self._preprocess_image(
|
||||
image=img,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_center_crop=do_center_crop,
|
||||
crop_size=crop_size,
|
||||
do_flip_channel_order=do_flip_channel_order,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
# the pretrained checkpoints assume images are BGR, not RGB
|
||||
if do_flip_channel_order:
|
||||
images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
for img in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = [
|
||||
self._preprocess_mask(
|
||||
segmentation_map=segmentation_map,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
do_center_crop=do_center_crop,
|
||||
crop_size=crop_size,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for segmentation_map in segmentation_maps
|
||||
]
|
||||
|
||||
data["labels"] = segmentation_maps
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
|
||||
|
||||
@@ -16,13 +16,20 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import MobileViTImageProcessor
|
||||
|
||||
|
||||
@@ -79,6 +86,26 @@ class MobileViTImageProcessingTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
def prepare_semantic_single_inputs():
|
||||
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
|
||||
image = Image.open(dataset[0]["file"])
|
||||
map = Image.open(dataset[1]["file"])
|
||||
|
||||
return image, map
|
||||
|
||||
|
||||
def prepare_semantic_batch_inputs():
|
||||
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
|
||||
image1 = Image.open(dataset[0]["file"])
|
||||
map1 = Image.open(dataset[1]["file"])
|
||||
image2 = Image.open(dataset[2]["file"])
|
||||
map2 = Image.open(dataset[3]["file"])
|
||||
|
||||
return [image1, image2], [map1, map2]
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
@@ -107,3 +134,109 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
|
||||
def test_call_segmentation_maps(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
maps = []
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
maps.append(torch.zeros(image.shape[-2:]).long())
|
||||
|
||||
# Test not batched input
|
||||
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched
|
||||
encoding = image_processing(image_inputs, maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test not batched input (PIL images)
|
||||
image, segmentation_map = prepare_semantic_single_inputs()
|
||||
|
||||
encoding = image_processing(image, segmentation_map, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched input (PIL images)
|
||||
images, segmentation_maps = prepare_semantic_batch_inputs()
|
||||
|
||||
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
Reference in New Issue
Block a user