Optionally preprocess segmentation maps for MobileViT (#28420)
* optionally preprocess segmentation maps for mobilevit * changed pretrained model name to that of segmentation model * removed voc-deeplabv3 from model archive list * added preprocess_image and preprocess_mask methods for processing images and segmentation masks respectively * added tests for segmentation masks based on segformer feature extractor * use crop_size instead of size * reverting to initial model
This commit is contained in:
@@ -16,13 +16,20 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import MobileViTImageProcessor
|
||||
|
||||
|
||||
@@ -79,6 +86,26 @@ class MobileViTImageProcessingTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
def prepare_semantic_single_inputs():
|
||||
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
|
||||
image = Image.open(dataset[0]["file"])
|
||||
map = Image.open(dataset[1]["file"])
|
||||
|
||||
return image, map
|
||||
|
||||
|
||||
def prepare_semantic_batch_inputs():
|
||||
dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
|
||||
image1 = Image.open(dataset[0]["file"])
|
||||
map1 = Image.open(dataset[1]["file"])
|
||||
image2 = Image.open(dataset[2]["file"])
|
||||
map2 = Image.open(dataset[3]["file"])
|
||||
|
||||
return [image1, image2], [map1, map2]
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
@@ -107,3 +134,109 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
|
||||
def test_call_segmentation_maps(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
|
||||
maps = []
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
maps.append(torch.zeros(image.shape[-2:]).long())
|
||||
|
||||
# Test not batched input
|
||||
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched
|
||||
encoding = image_processing(image_inputs, maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test not batched input (PIL images)
|
||||
image, segmentation_map = prepare_semantic_single_inputs()
|
||||
|
||||
encoding = image_processing(image, segmentation_map, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
# Test batched input (PIL images)
|
||||
images, segmentation_maps = prepare_semantic_batch_inputs()
|
||||
|
||||
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding["pixel_values"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
encoding["labels"].shape,
|
||||
(
|
||||
2,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.assertEqual(encoding["labels"].dtype, torch.long)
|
||||
self.assertTrue(encoding["labels"].min().item() >= 0)
|
||||
self.assertTrue(encoding["labels"].max().item() <= 255)
|
||||
|
||||
Reference in New Issue
Block a user