Add segmentation_maps support to MobileNetV2ImageProcessor (#37312)

* Add `segmentation_maps` support to mobilenet_v2 image processor and `reduce_labels` to mobilevit

* Changed mobilenetv2 tests to support fastimageprocessor

* added `segmentation_maps` support to fast image processor

* reverted to upstream/main

* Add optional

* Use autodocstring

* Changed docs

* Docs fix

* Changed fp to match beit fp

* Change typing imports

* Fixed repo inconsistency

* Added fast-slow equivalence tests

* Removed unnecessary call

* Add `reduce_labels` to Mobilevit fast processor

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Mikhail Moskovchenko
2025-07-07 22:34:59 +05:00
committed by GitHub
parent b96f213fcf
commit 3993ee1e98
7 changed files with 723 additions and 52 deletions

View File

@@ -15,13 +15,21 @@
import unittest
import requests
from datasets import load_dataset
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torchvision_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import MobileNetV2ImageProcessor
if is_torchvision_available():
@@ -41,6 +49,7 @@ class MobileNetV2ImageProcessingTester:
size=None,
do_center_crop=True,
crop_size=None,
do_reduce_labels=False,
):
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
@@ -54,6 +63,7 @@ class MobileNetV2ImageProcessingTester:
self.size = size
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_reduce_labels = do_reduce_labels
def prepare_image_processor_dict(self):
return {
@@ -61,6 +71,7 @@ class MobileNetV2ImageProcessingTester:
"size": self.size,
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"do_reduce_labels": self.do_reduce_labels,
}
def expected_output_image_shape(self, images):
@@ -78,6 +89,17 @@ class MobileNetV2ImageProcessingTester:
)
def prepare_semantic_single_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
example = ds[0]
return example["image"], example["map"]
def prepare_semantic_batch_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
return list(ds["image"][:2]), list(ds["map"][:2])
@require_torch
@require_vision
class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
@@ -99,13 +121,167 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_center_crop"))
self.assertTrue(hasattr(image_processor, "crop_size"))
self.assertTrue(hasattr(image_processor, "do_reduce_labels"))
def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
self.assertEqual(image_processor.do_reduce_labels, False)
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
image_processor = image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.do_reduce_labels, True)
def test_call_segmentation_maps(self):
# Initialize image_processing
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
maps = []
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
maps.append(torch.zeros(image.shape[-2:]).long())
# Test not batched input
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched
encoding = image_processing(image_inputs, maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test not batched input (PIL images)
image, segmentation_map = prepare_semantic_single_inputs()
encoding = image_processing(image, segmentation_map, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched input (PIL images)
images, segmentation_maps = prepare_semantic_batch_inputs()
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
2,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
def test_reduce_labels(self):
# Initialize image_processing
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
image, map = prepare_semantic_single_inputs()
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 150)
image_processing.do_reduce_labels = True
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
# Test with single image
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
# Test with single image and segmentation map
image, segmentation_map = prepare_semantic_single_inputs()
encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt")
encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)

View File

@@ -50,6 +50,7 @@ class MobileViTImageProcessingTester:
do_center_crop=True,
crop_size=None,
do_flip_channel_order=True,
do_reduce_labels=False,
):
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
@@ -64,6 +65,7 @@ class MobileViTImageProcessingTester:
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_flip_channel_order = do_flip_channel_order
self.do_reduce_labels = do_reduce_labels
def prepare_image_processor_dict(self):
return {
@@ -72,6 +74,7 @@ class MobileViTImageProcessingTester:
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"do_flip_channel_order": self.do_flip_channel_order,
"do_reduce_labels": self.do_reduce_labels,
}
def expected_output_image_shape(self, images):
@@ -122,16 +125,21 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_flip_channel_order"))
self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
self.assertEqual(image_processor.do_reduce_labels, False)
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.do_reduce_labels, True)
def test_call_segmentation_maps(self):
for image_processing_class in self.image_processor_list:
@@ -240,6 +248,22 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
def test_reduce_labels(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
image, map = prepare_semantic_single_inputs()
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 150)
image_processing.do_reduce_labels = True
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):