[SegGPT] Fix seggpt image processor (#29550)
* Fixed SegGptImageProcessor to handle 2D and 3D prompt mask inputs * Added new test to check prompt mask equivalence * New proposal * Better proposal * Removed unnecessary method * Updated seggpt docs * Introduced do_convert_rgb * nits
This commit is contained in:
@@ -30,6 +30,8 @@ if is_torch_available():
|
||||
from transformers.models.seggpt.modeling_seggpt import SegGptImageSegmentationOutput
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import SegGptImageProcessor
|
||||
|
||||
|
||||
@@ -147,7 +149,7 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
mask_rgb = mask_binary.convert("RGB")
|
||||
|
||||
inputs_binary = image_processor(images=None, prompt_masks=mask_binary, return_tensors="pt")
|
||||
inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt")
|
||||
inputs_rgb = image_processor(images=None, prompt_masks=mask_rgb, return_tensors="pt", do_convert_rgb=False)
|
||||
|
||||
self.assertTrue((inputs_binary["prompt_masks"] == inputs_rgb["prompt_masks"]).all().item())
|
||||
|
||||
@@ -196,7 +198,11 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processor = SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large")
|
||||
|
||||
inputs = image_processor(
|
||||
images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
|
||||
images=input_image,
|
||||
prompt_images=prompt_image,
|
||||
prompt_masks=prompt_mask,
|
||||
return_tensors="pt",
|
||||
do_convert_rgb=False,
|
||||
)
|
||||
|
||||
# Verify pixel values
|
||||
@@ -229,3 +235,76 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
torch.allclose(inputs.prompt_pixel_values[0, :, :3, :3], expected_prompt_pixel_values, atol=1e-4)
|
||||
)
|
||||
self.assertTrue(torch.allclose(inputs.prompt_masks[0, :, :3, :3], expected_prompt_masks, atol=1e-4))
|
||||
|
||||
def test_prompt_mask_equivalence(self):
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
image_size = self.image_processor_tester.image_size
|
||||
|
||||
# Single Mask Examples
|
||||
expected_single_shape = [1, 3, image_size, image_size]
|
||||
|
||||
# Single Semantic Map (2D)
|
||||
image_np_2d = np.ones((image_size, image_size))
|
||||
image_pt_2d = torch.ones((image_size, image_size))
|
||||
image_pil_2d = Image.fromarray(image_np_2d)
|
||||
|
||||
inputs_np_2d = image_processor(images=None, prompt_masks=image_np_2d, return_tensors="pt")
|
||||
inputs_pt_2d = image_processor(images=None, prompt_masks=image_pt_2d, return_tensors="pt")
|
||||
inputs_pil_2d = image_processor(images=None, prompt_masks=image_pil_2d, return_tensors="pt")
|
||||
|
||||
self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pt_2d["prompt_masks"]).all().item())
|
||||
self.assertTrue((inputs_np_2d["prompt_masks"] == inputs_pil_2d["prompt_masks"]).all().item())
|
||||
self.assertEqual(list(inputs_np_2d["prompt_masks"].shape), expected_single_shape)
|
||||
|
||||
# Single RGB Images (3D)
|
||||
image_np_3d = np.ones((3, image_size, image_size))
|
||||
image_pt_3d = torch.ones((3, image_size, image_size))
|
||||
image_pil_3d = Image.fromarray(image_np_3d.transpose(1, 2, 0).astype(np.uint8))
|
||||
|
||||
inputs_np_3d = image_processor(
|
||||
images=None, prompt_masks=image_np_3d, return_tensors="pt", do_convert_rgb=False
|
||||
)
|
||||
inputs_pt_3d = image_processor(
|
||||
images=None, prompt_masks=image_pt_3d, return_tensors="pt", do_convert_rgb=False
|
||||
)
|
||||
inputs_pil_3d = image_processor(
|
||||
images=None, prompt_masks=image_pil_3d, return_tensors="pt", do_convert_rgb=False
|
||||
)
|
||||
|
||||
self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pt_3d["prompt_masks"]).all().item())
|
||||
self.assertTrue((inputs_np_3d["prompt_masks"] == inputs_pil_3d["prompt_masks"]).all().item())
|
||||
self.assertEqual(list(inputs_np_3d["prompt_masks"].shape), expected_single_shape)
|
||||
|
||||
# Batched Examples
|
||||
expected_batched_shape = [2, 3, image_size, image_size]
|
||||
|
||||
# Batched Semantic Maps (3D)
|
||||
image_np_2d_batched = np.ones((2, image_size, image_size))
|
||||
image_pt_2d_batched = torch.ones((2, image_size, image_size))
|
||||
|
||||
inputs_np_2d_batched = image_processor(images=None, prompt_masks=image_np_2d_batched, return_tensors="pt")
|
||||
inputs_pt_2d_batched = image_processor(images=None, prompt_masks=image_pt_2d_batched, return_tensors="pt")
|
||||
|
||||
self.assertTrue((inputs_np_2d_batched["prompt_masks"] == inputs_pt_2d_batched["prompt_masks"]).all().item())
|
||||
self.assertEqual(list(inputs_np_2d_batched["prompt_masks"].shape), expected_batched_shape)
|
||||
|
||||
# Batched RGB images
|
||||
image_np_4d = np.ones((2, 3, image_size, image_size))
|
||||
image_pt_4d = torch.ones((2, 3, image_size, image_size))
|
||||
|
||||
inputs_np_4d = image_processor(
|
||||
images=None, prompt_masks=image_np_4d, return_tensors="pt", do_convert_rgb=False
|
||||
)
|
||||
inputs_pt_4d = image_processor(
|
||||
images=None, prompt_masks=image_pt_4d, return_tensors="pt", do_convert_rgb=False
|
||||
)
|
||||
|
||||
self.assertTrue((inputs_np_4d["prompt_masks"] == inputs_pt_4d["prompt_masks"]).all().item())
|
||||
self.assertEqual(list(inputs_np_4d["prompt_masks"].shape), expected_batched_shape)
|
||||
|
||||
# Comparing Single and Batched Examples
|
||||
self.assertTrue((inputs_np_2d["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
|
||||
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_2d["prompt_masks"][0]).all().item())
|
||||
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
|
||||
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_4d["prompt_masks"][0]).all().item())
|
||||
self.assertTrue((inputs_np_2d_batched["prompt_masks"][0] == inputs_np_3d["prompt_masks"][0]).all().item())
|
||||
|
||||
Reference in New Issue
Block a user