Optionally preprocess segmentation maps for MobileViT (#28420)
* optionally preprocess segmentation maps for mobilevit * changed pretrained model name to that of segmentation model * removed voc-deeplabv3 from model archive list * added preprocess_image and preprocess_mask methods for processing images and segmentation masks respectively * added tests for segmentation masks based on segformer feature extractor * use crop_size instead of size * reverting to initial model
This commit is contained in:
@@ -19,12 +19,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format
|
||||||
flip_channel_order,
|
|
||||||
get_resize_output_image_size,
|
|
||||||
resize,
|
|
||||||
to_channel_dimension_format,
|
|
||||||
)
|
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
ChannelDimension,
|
ChannelDimension,
|
||||||
ImageInput,
|
ImageInput,
|
||||||
@@ -178,9 +173,126 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
|||||||
"""
|
"""
|
||||||
return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
|
return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
def __call__(self, images, segmentation_maps=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Preprocesses a batch of images and optionally segmentation maps.
|
||||||
|
|
||||||
|
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
|
||||||
|
passed in as positional arguments.
|
||||||
|
"""
|
||||||
|
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
||||||
|
|
||||||
|
def _preprocess(
|
||||||
|
self,
|
||||||
|
image: ImageInput,
|
||||||
|
do_resize: bool,
|
||||||
|
do_rescale: bool,
|
||||||
|
do_center_crop: bool,
|
||||||
|
do_flip_channel_order: bool,
|
||||||
|
size: Optional[Dict[str, int]] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
crop_size: Optional[Dict[str, int]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
):
|
||||||
|
if do_resize:
|
||||||
|
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
if do_rescale:
|
||||||
|
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
if do_center_crop:
|
||||||
|
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
if do_flip_channel_order:
|
||||||
|
image = self.flip_channel_order(image, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _preprocess_image(
|
||||||
|
self,
|
||||||
|
image: ImageInput,
|
||||||
|
do_resize: bool = None,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
do_rescale: bool = None,
|
||||||
|
rescale_factor: float = None,
|
||||||
|
do_center_crop: bool = None,
|
||||||
|
crop_size: Dict[str, int] = None,
|
||||||
|
do_flip_channel_order: bool = None,
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Preprocesses a single image."""
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
image = to_numpy_array(image)
|
||||||
|
if is_scaled_image(image) and do_rescale:
|
||||||
|
logger.warning_once(
|
||||||
|
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||||
|
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||||
|
)
|
||||||
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
|
|
||||||
|
image = self._preprocess(
|
||||||
|
image=image,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_center_crop=do_center_crop,
|
||||||
|
crop_size=crop_size,
|
||||||
|
do_flip_channel_order=do_flip_channel_order,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _preprocess_mask(
|
||||||
|
self,
|
||||||
|
segmentation_map: ImageInput,
|
||||||
|
do_resize: bool = None,
|
||||||
|
size: Dict[str, int] = None,
|
||||||
|
do_center_crop: bool = None,
|
||||||
|
crop_size: Dict[str, int] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Preprocesses a single mask."""
|
||||||
|
segmentation_map = to_numpy_array(segmentation_map)
|
||||||
|
# Add channel dimension if missing - needed for certain transformations
|
||||||
|
if segmentation_map.ndim == 2:
|
||||||
|
added_channel_dim = True
|
||||||
|
segmentation_map = segmentation_map[None, ...]
|
||||||
|
input_data_format = ChannelDimension.FIRST
|
||||||
|
else:
|
||||||
|
added_channel_dim = False
|
||||||
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
||||||
|
|
||||||
|
segmentation_map = self._preprocess(
|
||||||
|
image=segmentation_map,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=PILImageResampling.NEAREST,
|
||||||
|
do_rescale=False,
|
||||||
|
do_center_crop=do_center_crop,
|
||||||
|
crop_size=crop_size,
|
||||||
|
do_flip_channel_order=False,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
# Remove extra channel dimension if added for processing
|
||||||
|
if added_channel_dim:
|
||||||
|
segmentation_map = segmentation_map.squeeze(0)
|
||||||
|
segmentation_map = segmentation_map.astype(np.int64)
|
||||||
|
return segmentation_map
|
||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
images: ImageInput,
|
images: ImageInput,
|
||||||
|
segmentation_maps: Optional[ImageInput] = None,
|
||||||
do_resize: bool = None,
|
do_resize: bool = None,
|
||||||
size: Dict[str, int] = None,
|
size: Dict[str, int] = None,
|
||||||
resample: PILImageResampling = None,
|
resample: PILImageResampling = None,
|
||||||
@@ -201,6 +313,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
|||||||
images (`ImageInput`):
|
images (`ImageInput`):
|
||||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||||
|
segmentation_maps (`ImageInput`, *optional*):
|
||||||
|
Segmentation map to preprocess.
|
||||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
Whether to resize the image.
|
Whether to resize the image.
|
||||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
@@ -251,6 +365,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
|||||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||||
|
|
||||||
images = make_list_of_images(images)
|
images = make_list_of_images(images)
|
||||||
|
if segmentation_maps is not None:
|
||||||
|
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
||||||
|
|
||||||
if not valid_images(images):
|
if not valid_images(images):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -258,6 +374,12 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
|||||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if segmentation_maps is not None and not valid_images(segmentation_maps):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
if do_resize and size is None:
|
if do_resize and size is None:
|
||||||
raise ValueError("Size must be specified if do_resize is True.")
|
raise ValueError("Size must be specified if do_resize is True.")
|
||||||
|
|
||||||
@@ -267,45 +389,40 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
|||||||
if do_center_crop and crop_size is None:
|
if do_center_crop and crop_size is None:
|
||||||
raise ValueError("Crop size must be specified if do_center_crop is True.")
|
raise ValueError("Crop size must be specified if do_center_crop is True.")
|
||||||
|
|
||||||
# All transformations expect numpy arrays.
|
|
||||||
images = [to_numpy_array(image) for image in images]
|
|
||||||
|
|
||||||
if is_scaled_image(images[0]) and do_rescale:
|
|
||||||
logger.warning_once(
|
|
||||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
|
||||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
|
||||||
)
|
|
||||||
|
|
||||||
if input_data_format is None:
|
|
||||||
# We assume that all images have the same channel dimension format.
|
|
||||||
input_data_format = infer_channel_dimension_format(images[0])
|
|
||||||
|
|
||||||
if do_resize:
|
|
||||||
images = [
|
|
||||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
|
||||||
for image in images
|
|
||||||
]
|
|
||||||
|
|
||||||
if do_center_crop:
|
|
||||||
images = [
|
|
||||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
|
||||||
]
|
|
||||||
|
|
||||||
if do_rescale:
|
|
||||||
images = [
|
|
||||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
|
||||||
for image in images
|
|
||||||
]
|
|
||||||
|
|
||||||
# the pretrained checkpoints assume images are BGR, not RGB
|
|
||||||
if do_flip_channel_order:
|
|
||||||
images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images]
|
|
||||||
|
|
||||||
images = [
|
images = [
|
||||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
self._preprocess_image(
|
||||||
|
image=img,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
do_center_crop=do_center_crop,
|
||||||
|
crop_size=crop_size,
|
||||||
|
do_flip_channel_order=do_flip_channel_order,
|
||||||
|
data_format=data_format,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for img in images
|
||||||
]
|
]
|
||||||
|
|
||||||
data = {"pixel_values": images}
|
data = {"pixel_values": images}
|
||||||
|
|
||||||
|
if segmentation_maps is not None:
|
||||||
|
segmentation_maps = [
|
||||||
|
self._preprocess_mask(
|
||||||
|
segmentation_map=segmentation_map,
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
do_center_crop=do_center_crop,
|
||||||
|
crop_size=crop_size,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
)
|
||||||
|
for segmentation_map in segmentation_maps
|
||||||
|
]
|
||||||
|
|
||||||
|
data["labels"] = segmentation_maps
|
||||||
|
|
||||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
|
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
|
||||||
|
|||||||
@@ -16,13 +16,20 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
from transformers.utils import is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from transformers import MobileViTImageProcessor
|
from transformers import MobileViTImageProcessor
|
||||||
|
|
||||||
|
|
||||||
@@ -79,6 +86,26 @@ class MobileViTImageProcessingTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_semantic_single_inputs():
|
||||||
|
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||||
|
|
||||||
|
image = Image.open(dataset[0]["file"])
|
||||||
|
map = Image.open(dataset[1]["file"])
|
||||||
|
|
||||||
|
return image, map
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_semantic_batch_inputs():
|
||||||
|
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||||
|
|
||||||
|
image1 = Image.open(dataset[0]["file"])
|
||||||
|
map1 = Image.open(dataset[1]["file"])
|
||||||
|
image2 = Image.open(dataset[2]["file"])
|
||||||
|
map2 = Image.open(dataset[3]["file"])
|
||||||
|
|
||||||
|
return [image1, image2], [map1, map2]
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||||
@@ -107,3 +134,109 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
||||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||||
|
|
||||||
|
def test_call_segmentation_maps(self):
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# create random PyTorch tensors
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||||
|
maps = []
|
||||||
|
for image in image_inputs:
|
||||||
|
self.assertIsInstance(image, torch.Tensor)
|
||||||
|
maps.append(torch.zeros(image.shape[-2:]).long())
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["pixel_values"].shape,
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.image_processor_tester.num_channels,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["labels"].shape,
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
encoding = image_processing(image_inputs, maps, return_tensors="pt")
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["pixel_values"].shape,
|
||||||
|
(
|
||||||
|
self.image_processor_tester.batch_size,
|
||||||
|
self.image_processor_tester.num_channels,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["labels"].shape,
|
||||||
|
(
|
||||||
|
self.image_processor_tester.batch_size,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
|
# Test not batched input (PIL images)
|
||||||
|
image, segmentation_map = prepare_semantic_single_inputs()
|
||||||
|
|
||||||
|
encoding = image_processing(image, segmentation_map, return_tensors="pt")
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["pixel_values"].shape,
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.image_processor_tester.num_channels,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["labels"].shape,
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|
||||||
|
# Test batched input (PIL images)
|
||||||
|
images, segmentation_maps = prepare_semantic_batch_inputs()
|
||||||
|
|
||||||
|
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["pixel_values"].shape,
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
self.image_processor_tester.num_channels,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
encoding["labels"].shape,
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
self.image_processor_tester.crop_size["height"],
|
||||||
|
self.image_processor_tester.crop_size["width"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||||
|
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||||
|
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||||
|
|||||||
Reference in New Issue
Block a user