Optionally preprocess segmentation maps for MobileViT (#28420)

* optionally preprocess segmentation maps for mobilevit

* changed pretrained model name to that of segmentation model

* removed voc-deeplabv3 from model archive list

* added preprocess_image and preprocess_mask methods for processing images and segmentation masks respectively

* added tests for segmentation masks based on segformer feature extractor

* use crop_size instead of size

* reverting to initial model
This commit is contained in:
Harisankar Babu
2024-01-11 15:52:14 +01:00
committed by GitHub
parent 95091e1582
commit d560637885
2 changed files with 292 additions and 42 deletions

View File

@@ -16,13 +16,20 @@
import unittest
from datasets import load_dataset
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import MobileViTImageProcessor
@@ -79,6 +86,26 @@ class MobileViTImageProcessingTester(unittest.TestCase):
)
def prepare_semantic_single_inputs():
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(dataset[0]["file"])
map = Image.open(dataset[1]["file"])
return image, map
def prepare_semantic_batch_inputs():
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image1 = Image.open(dataset[0]["file"])
map1 = Image.open(dataset[1]["file"])
image2 = Image.open(dataset[2]["file"])
map2 = Image.open(dataset[3]["file"])
return [image1, image2], [map1, map2]
@require_torch
@require_vision
class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
@@ -107,3 +134,109 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
def test_call_segmentation_maps(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
maps = []
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
maps.append(torch.zeros(image.shape[-2:]).long())
# Test not batched input
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched
encoding = image_processing(image_inputs, maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test not batched input (PIL images)
image, segmentation_map = prepare_semantic_single_inputs()
encoding = image_processing(image, segmentation_map, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched input (PIL images)
images, segmentation_maps = prepare_semantic_batch_inputs()
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
2,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)