[SegGPT] Fix seggpt image processor (#29550)

* Fixed SegGptImageProcessor to handle 2D and 3D prompt mask inputs

* Added new test to check prompt mask equivalence

* New proposal

* Better proposal

* Removed unnecessary method

* Updated seggpt docs

* Introduced do_convert_rgb

* nits
This commit is contained in:
Eduardo Pacheco
2024-04-26 20:40:12 +02:00
committed by GitHub
parent c793b26f2e
commit 6d4cabda26
4 changed files with 148 additions and 65 deletions

View File

@@ -30,6 +30,8 @@ if is_torch_available():
from transformers.models.seggpt.modeling_seggpt import SegGptImageSegmentationOutput
if is_vision_available():
from PIL import Image
from transformers import SegGptImageProcessor
@@ -147,7 +149,7 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
mask_rgb = mask_binary.convert("RGB")
inputs_binary = image_processor(images=None, prompt_masks=mask_binary, return_tensors="pt")
inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt")
inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt", do_convert_rgb=False)
self.assertTrue((inputs_binary["prompt_masks"] == inputs_rgb["prompt_masks"]).all().item())
@@ -196,7 +198,11 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large")
inputs = image_processor(
images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
images=input_image,
prompt_images=prompt_image,
prompt_masks=prompt_mask,
return_tensors="pt",
do_convert_rgb=False,
)
# Verify pixel values
@@ -229,3 +235,76 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
)
self.assertTrue(torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4))
def test_prompt_mask_equivalence(self):
image_processor = self.image_processing_class(**self.image_processor_dict)
image_size = self.image_processor_tester.image_size
# Single Mask Examples
expected_single_shape = [1, 3, image_size, image_size]
# Single Semantic Map (2D)
image_np_2d = np.ones((image_size, image_size))
image_pt_2d = torch.ones((image_size, image_size))
image_pil_2d = Image.fromarray(image_np_2d)
inputs_np_2d = image_processor(images=None, prompt_masks=image_np_2d, return_tensors="pt")
inputs_pt_2d = image_processor(images=None, prompt_masks=image_pt_2d, return_tensors="pt")
inputs_pil_2d = image_processor(images=None, prompt_masks=image_pil_2d, return_tensors="pt")
self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pt_2d["prompt_masks"]).all().item())
self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pil_2d["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_2d["prompt_masks"].shape), expected_single_shape)
# Single RGB Images (3D)
image_np_3d = np.ones((3, image_size, image_size))
image_pt_3d = torch.ones((3, image_size, image_size))
image_pil_3d = Image.fromarray(image_np_3d.transpose(1, 2, 0).astype(np.uint8))
inputs_np_3d = image_processor(
images=None, prompt_masks=image_np_3d, return_tensors="pt", do_convert_rgb=False
)
inputs_pt_3d = image_processor(
images=None, prompt_masks=image_pt_3d, return_tensors="pt", do_convert_rgb=False
)
inputs_pil_3d = image_processor(
images=None, prompt_masks=image_pil_3d, return_tensors="pt", do_convert_rgb=False
)
self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pt_3d["prompt_masks"]).all().item())
self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pil_3d["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_3d["prompt_masks"].shape), expected_single_shape)
# Batched Examples
expected_batched_shape = [2, 3, image_size, image_size]
# Batched Semantic Maps (3D)
image_np_2d_batched = np.ones((2, image_size, image_size))
image_pt_2d_batched = torch.ones((2, image_size, image_size))
inputs_np_2d_batched = image_processor(images=None, prompt_masks=image_np_2d_batched, return_tensors="pt")
inputs_pt_2d_batched = image_processor(images=None, prompt_masks=image_pt_2d_batched, return_tensors="pt")
self.assertTrue((inputs_np_2d_batched["prompt_masks"] == inputs_pt_2d_batched["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_2d_batched["prompt_masks"].shape), expected_batched_shape)
# Batched RGB images
image_np_4d = np.ones((2, 3, image_size, image_size))
image_pt_4d = torch.ones((2, 3, image_size, image_size))
inputs_np_4d = image_processor(
images=None, prompt_masks=image_np_4d, return_tensors="pt", do_convert_rgb=False
)
inputs_pt_4d = image_processor(
images=None, prompt_masks=image_pt_4d, return_tensors="pt", do_convert_rgb=False
)
self.assertTrue((inputs_np_4d["prompt_masks"] == inputs_pt_4d["prompt_masks"]).all().item())
self.assertEqual(list(inputs_np_4d["prompt_masks"].shape), expected_batched_shape)
# Comparing Single and Batched Examples
self.assertTrue((inputs_np_2d["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_2d["prompt_masks"][0]).all().item())
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_4d["prompt_masks"][0]).all().item())
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())