From e4227eb4d4cc1ac279865a12b52f957b818616f7 Mon Sep 17 00:00:00 2001 From: Marcel <46164444+MSt-10@users.noreply.github.com> Date: Thu, 30 Jan 2025 20:08:38 +0100 Subject: [PATCH] Handle empty change indices in SAM's mask to rle conversion (#35665) * Handle empty change indices in RLE conversion for masks * [test] Add unit tests for RLE encoding of masks in SamProcessor * [test] Update RLE conversion tests to use TensorFlow implementation * [test] Fix formatting in SamProcessorTest according to check_code_quality action * [test] Fix formatting in SamProcessorTest according to check_code_quality * [test] Refactored rle test cases into one test and used tf tensors in tf test cases * [test] Fix: removed self parameter from refactored methods * [test] Removed nested methods in run-length encoding tests for PyTorch and TensorFlow * [test] Added description to individual to run-length encoding tests for PyTorch and TensorFlow. --- .../models/sam/image_processing_sam.py | 16 ++++ tests/models/sam/test_processor_sam.py | 76 +++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index a1cace89fa..ae643b367f 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -1373,6 +1373,14 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"): out = [] for i in range(batch_size): cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + if len(cur_idxs) == 0: + # No changes => either all 0 or all 1 + # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. + if input_mask[i, 0] == 0: + out.append({"size": [height, width], "counts": [height * width]}) + else: + out.append({"size": [height, width], "counts": [0, height * width]}) + continue btw_idxs = cur_idxs[1:] - cur_idxs[:-1] counts = [] if input_mask[i, 0] == 0 else [0] counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]] @@ -1396,6 +1404,14 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"): out = [] for i in range(batch_size): cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + if len(cur_idxs) == 0: + # No changes => either all 0 or all 1 + # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. + if input_mask[i, 0] == 0: + out.append({"size": [height, width], "counts": [height * width]}) + else: + out.append({"size": [height, width], "counts": [0, height * width]}) + continue btw_idxs = cur_idxs[1:] - cur_idxs[:-1] counts = [] if input_mask[i, 0] == 0 else [0] counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]] diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 654f892062..3a2814f8f4 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -37,9 +37,13 @@ if is_vision_available(): if is_torch_available(): import torch + from transformers.models.sam.image_processing_sam import _mask_to_rle_pytorch + if is_tf_available(): import tensorflow as tf + from transformers.models.sam.image_processing_sam import _mask_to_rle_tf + @require_vision @require_torchvision @@ -161,6 +165,42 @@ class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase): with self.assertRaises(ValueError): masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size)) + def test_rle_encoding(self): + """ + Test the run-length encoding function. + """ + # Test that a mask of all zeros returns a single run [height * width]. + input_mask = torch.zeros((1, 2, 2), dtype=torch.long) # shape: 1 x 2 x 2 + rle = _mask_to_rle_pytorch(input_mask) + + self.assertEqual(len(rle), 1) + self.assertEqual(rle[0]["size"], [2, 2]) + # For a 2x2 all-zero mask, we expect a single run of length 4: + self.assertEqual(rle[0]["counts"], [4]) + + # Test that a mask of all ones returns [0, height * width]. + input_mask = torch.ones((1, 2, 2), dtype=torch.long) # shape: 1 x 2 x 2 + rle = _mask_to_rle_pytorch(input_mask) + + self.assertEqual(len(rle), 1) + self.assertEqual(rle[0]["size"], [2, 2]) + # For a 2x2 all-one mask, we expect two runs: [0, 4]. + self.assertEqual(rle[0]["counts"], [0, 4]) + + # Test a mask with mixed 0s and 1s to ensure the run-length encoding is correct. + # Example mask: + # Row 0: [0, 1] + # Row 1: [1, 1] + # This is shape (1, 2, 2). + # Flattened in Fortran order -> [0, 1, 1, 1]. + # The RLE for [0,1,1,1] is [1, 3]. + input_mask = torch.tensor([[[0, 1], [1, 1]]], dtype=torch.long) + rle = _mask_to_rle_pytorch(input_mask) + + self.assertEqual(len(rle), 1) + self.assertEqual(rle[0]["size"], [2, 2]) + self.assertEqual(rle[0]["counts"], [1, 3]) # 1 zero, followed by 3 ones + @require_vision @require_tf @@ -244,6 +284,42 @@ class TFSamProcessorTest(unittest.TestCase): dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf" ) + def test_rle_encoding(self): + """ + Test the run-length encoding function. + """ + # Test that a mask of all zeros returns a single run [height * width]. + input_mask = tf.zeros((1, 2, 2), dtype=tf.int64) # shape: 1 x 2 x 2 + rle = _mask_to_rle_tf(input_mask) + + self.assertEqual(len(rle), 1) + self.assertEqual(rle[0]["size"], [2, 2]) + # For a 2x2 all-zero mask, we expect a single run of length 4: + self.assertEqual(rle[0]["counts"], [4]) + + # Test that a mask of all ones returns [0, height * width]. + input_mask = tf.ones((1, 2, 2), dtype=tf.int64) # shape: 1 x 2 x 2 + rle = _mask_to_rle_tf(input_mask) + + self.assertEqual(len(rle), 1) + self.assertEqual(rle[0]["size"], [2, 2]) + # For a 2x2 all-one mask, we expect two runs: [0, 4]. + self.assertEqual(rle[0]["counts"], [0, 4]) + + # Test a mask with mixed 0s and 1s to ensure the run-length encoding is correct. + # Example mask: + # Row 0: [0, 1] + # Row 1: [1, 1] + # This is shape (1, 2, 2). + # Flattened in Fortran order -> [0, 1, 1, 1]. + # The RLE for [0,1,1,1] is [1, 3]. + input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64) + rle = _mask_to_rle_tf(input_mask) + + self.assertEqual(len(rle), 1) + self.assertEqual(rle[0]["size"], [2, 2]) + self.assertEqual(rle[0]["counts"], [1, 3]) # 1 zero, followed by 3 ones + @require_vision @require_torchvision