From a1835195d134f5a244aed1212342be94fa27b40c Mon Sep 17 00:00:00 2001 From: StevenBucaille Date: Tue, 29 Oct 2024 10:36:03 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20[SuperPo?= =?UTF-8?q?int]=20Fix=20keypoint=20coordinate=20output=20and=20add=20post?= =?UTF-8?q?=20processing=20(#33200)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Added int conversion and unwrapping * test: added tests for post_process_keypoint_detection of SuperPointImageProcessor * docs: changed docs to include post_process_keypoint_detection method and switched from opencv to matplotlib * test: changed test to not depend on SuperPointModel forward * test: added missing require_torch decorator * docs: changed pyplot parameters for the keypoints to be more visible in the example * tests: changed import torch location to make test_flax and test_tf * Revert "tests: changed import torch location to make test_flax and test_tf" This reverts commit 39b32a2f69500bc7af01715fc7beae2260549afe. * tests: fixed import * chore: applied suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * tests: fixed import * tests: fixed import (bis) * tests: fixed import (ter) * feat: added choice of type for target_size and changed tests accordingly * docs: updated code snippet to reflect the addition of target size type choice in post process method * tests: fixed imports (...) * tests: fixed imports (...) * style: formatting file * docs: fixed typo from image[0] to image.size[0] * docs: added output image and fixed some tests * Update docs/source/en/model_doc/superpoint.md Co-authored-by: Pavel Iakubovskii * fix: included SuperPointKeypointDescriptionOutput in TYPE_CHECKING if statement and changed tests results to reflect changes to SuperPoint from absolute keypoints coordinates to relative * docs: changed SuperPoint's docs to print output instead of just accessing * style: applied make style * docs: added missing output type and precision in docstring of post_process_keypoint_detection * perf: deleted loop to perform keypoint conversion in one statement * fix: moved keypoint conversion at the end of model forward * docs: changed SuperPointInterestPointDecoder to SuperPointKeypointDecoder class name and added relative (x, y) coordinates information to its method * fix: changed type hint * refactor: removed unnecessary brackets * revert: SuperPointKeypointDecoder to SuperPointInterestPointDecoder * Update docs/source/en/model_doc/superpoint.md Co-authored-by: Pavel Iakubovskii --------- Co-authored-by: Steven Bucaille Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii --- docs/source/en/model_doc/superpoint.md | 35 +++++++---- .../superpoint/image_processing_superpoint.py | 59 ++++++++++++++++++- .../models/superpoint/modeling_superpoint.py | 10 +++- .../test_image_processing_superpoint.py | 54 ++++++++++++++++- .../superpoint/test_modeling_superpoint.py | 10 ++-- 5 files changed, 146 insertions(+), 22 deletions(-) diff --git a/docs/source/en/model_doc/superpoint.md b/docs/source/en/model_doc/superpoint.md index b9aab2f1b9..59e451adce 100644 --- a/docs/source/en/model_doc/superpoint.md +++ b/docs/source/en/model_doc/superpoint.md @@ -86,24 +86,32 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup inputs = processor(images, return_tensors="pt") outputs = model(**inputs) +image_sizes = [(image.height, image.width) for image in images] +outputs = processor.post_process_keypoint_detection(outputs, image_sizes) -for i in range(len(images)): - image_mask = outputs.mask[i] - image_indices = torch.nonzero(image_mask).squeeze() - image_keypoints = outputs.keypoints[i][image_indices] - image_scores = outputs.scores[i][image_indices] - image_descriptors = outputs.descriptors[i][image_indices] +for output in outputs: + for keypoints, scores, descriptors in zip(output["keypoints"], output["scores"], output["descriptors"]): + print(f"Keypoints: {keypoints}") + print(f"Scores: {scores}") + print(f"Descriptors: {descriptors}") ``` -You can then print the keypoints on the image to visualize the result : +You can then print the keypoints on the image of your choice to visualize the result: ```python -import cv2 -for keypoint, score in zip(image_keypoints, image_scores): - keypoint_x, keypoint_y = int(keypoint[0].item()), int(keypoint[1].item()) - color = tuple([score.item() * 255] * 3) - image = cv2.circle(image, (keypoint_x, keypoint_y), 2, color) -cv2.imwrite("output_image.png", image) +import matplotlib.pyplot as plt + +plt.axis("off") +plt.imshow(image_1) +plt.scatter( + outputs[0]["keypoints"][:, 0], + outputs[0]["keypoints"][:, 1], + c=outputs[0]["scores"] * 100, + s=outputs[0]["scores"] * 50, + alpha=0.8 +) +plt.savefig(f"output_image.png") ``` +![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/ZtFmphEhx8tcbEQqOolyE.png) This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille). The original code can be found [here](https://github.com/magicleap/SuperPointPretrainedNetwork). @@ -123,6 +131,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] SuperPointImageProcessor - preprocess +- post_process_keypoint_detection ## SuperPointForKeypointDetection diff --git a/src/transformers/models/superpoint/image_processing_superpoint.py b/src/transformers/models/superpoint/image_processing_superpoint.py index fbbb717570..65309b1c18 100644 --- a/src/transformers/models/superpoint/image_processing_superpoint.py +++ b/src/transformers/models/superpoint/image_processing_superpoint.py @@ -13,11 +13,11 @@ # limitations under the License. """Image processor class for SuperPoint.""" -from typing import Dict, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np -from ... import is_vision_available +from ... import is_torch_available, is_vision_available from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import resize, to_channel_dimension_format from ...image_utils import ( @@ -32,6 +32,12 @@ from ...image_utils import ( from ...utils import TensorType, logging, requires_backends +if is_torch_available(): + import torch + +if TYPE_CHECKING: + from .modeling_superpoint import SuperPointKeypointDescriptionOutput + if is_vision_available(): import PIL @@ -270,3 +276,52 @@ class SuperPointImageProcessor(BaseImageProcessor): data = {"pixel_values": images} return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_keypoint_detection( + self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, List[Tuple]] + ) -> List[Dict[str, "torch.Tensor"]]: + """ + Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + + Args: + outputs ([`SuperPointKeypointDescriptionOutput`]): + Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. This must be the original + image size (before any processing). + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according + to target_sizes, scores and descriptors for an image in the batch as predicted by the model. + """ + if len(outputs.mask) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + + if isinstance(target_sizes, List): + image_sizes = torch.tensor(target_sizes) + else: + if target_sizes.shape[1] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_sizes = target_sizes + + # Flip the image sizes to (width, height) and convert keypoints to absolute coordinates + image_sizes = torch.flip(image_sizes, [1]) + masked_keypoints = outputs.keypoints * image_sizes[:, None] + + # Convert masked_keypoints to int + masked_keypoints = masked_keypoints.to(torch.int32) + + results = [] + for image_mask, keypoints, scores, descriptors in zip( + outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors + ): + indices = torch.nonzero(image_mask).squeeze(1) + keypoints = keypoints[indices] + scores = scores[indices] + descriptors = descriptors[indices] + results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors}) + + return results diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py index cfd3dfd86e..1075de299a 100644 --- a/src/transformers/models/superpoint/modeling_superpoint.py +++ b/src/transformers/models/superpoint/modeling_superpoint.py @@ -239,7 +239,10 @@ class SuperPointInterestPointDecoder(nn.Module): return scores def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation""" + """ + Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation. + The keypoints are in the form of relative (x, y) coordinates. + """ _, height, width = scores.shape # Threshold keypoints by score value @@ -447,7 +450,7 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): pixel_values = self.extract_one_channel_pixel_values(pixel_values) - batch_size = pixel_values.shape[0] + batch_size, _, height, width = pixel_values.shape encoder_outputs = self.encoder( pixel_values, @@ -485,6 +488,9 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): descriptors[i, : _descriptors.shape[0]] = _descriptors mask[i, : _scores.shape[0]] = 1 + # Convert to relative coordinates + keypoints = keypoints / torch.tensor([width, height], device=keypoints.device) + hidden_states = encoder_outputs[1] if output_hidden_states else None if not return_dict: return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None) diff --git a/tests/models/superpoint/test_image_processing_superpoint.py b/tests/models/superpoint/test_image_processing_superpoint.py index 90bbf82d1e..c2eae87200 100644 --- a/tests/models/superpoint/test_image_processing_superpoint.py +++ b/tests/models/superpoint/test_image_processing_superpoint.py @@ -16,7 +16,7 @@ import unittest import numpy as np from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torch_available, is_vision_available from ...test_image_processing_common import ( ImageProcessingTestMixin, @@ -24,6 +24,11 @@ from ...test_image_processing_common import ( ) +if is_torch_available(): + import torch + + from transformers.models.superpoint.modeling_superpoint import SuperPointKeypointDescriptionOutput + if is_vision_available(): from transformers import SuperPointImageProcessor @@ -70,6 +75,23 @@ class SuperPointImageProcessingTester(unittest.TestCase): torchify=torchify, ) + def prepare_keypoint_detection_output(self, pixel_values): + max_number_keypoints = 50 + batch_size = len(pixel_values) + mask = torch.zeros((batch_size, max_number_keypoints)) + keypoints = torch.zeros((batch_size, max_number_keypoints, 2)) + scores = torch.zeros((batch_size, max_number_keypoints)) + descriptors = torch.zeros((batch_size, max_number_keypoints, 16)) + for i in range(batch_size): + random_number_keypoints = np.random.randint(0, max_number_keypoints) + mask[i, :random_number_keypoints] = 1 + keypoints[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 2)) + scores[i, :random_number_keypoints] = torch.rand((random_number_keypoints,)) + descriptors[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 16)) + return SuperPointKeypointDescriptionOutput( + loss=None, keypoints=keypoints, scores=scores, descriptors=descriptors, mask=mask, hidden_states=None + ) + @require_torch @require_vision @@ -110,3 +132,33 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) pre_processed_images = image_processor.preprocess(image_inputs) for image in pre_processed_images["pixel_values"]: self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])) + + @require_torch + def test_post_processing_keypoint_detection(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs() + pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt") + outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images) + + def check_post_processed_output(post_processed_output, image_size): + for post_processed_output, image_size in zip(post_processed_output, image_size): + self.assertTrue("keypoints" in post_processed_output) + self.assertTrue("descriptors" in post_processed_output) + self.assertTrue("scores" in post_processed_output) + keypoints = post_processed_output["keypoints"] + all_below_image_size = torch.all(keypoints[:, 0] <= image_size[1]) and torch.all( + keypoints[:, 1] <= image_size[0] + ) + all_above_zero = torch.all(keypoints[:, 0] >= 0) and torch.all(keypoints[:, 1] >= 0) + self.assertTrue(all_below_image_size) + self.assertTrue(all_above_zero) + + tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs] + tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes) + + check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes) + + tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1) + tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes) + + check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes) diff --git a/tests/models/superpoint/test_modeling_superpoint.py b/tests/models/superpoint/test_modeling_superpoint.py index 25c384a795..8db435502c 100644 --- a/tests/models/superpoint/test_modeling_superpoint.py +++ b/tests/models/superpoint/test_modeling_superpoint.py @@ -260,7 +260,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase): inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) with torch.no_grad(): outputs = model(**inputs) - expected_number_keypoints_image0 = 567 + expected_number_keypoints_image0 = 568 expected_number_keypoints_image1 = 830 expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1) expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2)) @@ -275,11 +275,13 @@ class SuperPointModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape) self.assertEqual(outputs.scores.shape, expected_scores_shape) self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape) - expected_keypoints_image0_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]).to(torch_device) + expected_keypoints_image0_values = torch.tensor([[0.75, 0.0188], [0.7719, 0.0188], [0.7641, 0.0333]]).to( + torch_device + ) expected_scores_image0_values = torch.tensor( - [0.0064, 0.0137, 0.0589, 0.0723, 0.5166, 0.0174, 0.1515, 0.2054, 0.0334] + [0.0064, 0.0139, 0.0591, 0.0727, 0.5170, 0.0175, 0.1526, 0.2057, 0.0335] ).to(torch_device) - expected_descriptors_image0_value = torch.tensor(-0.1096).to(torch_device) + expected_descriptors_image0_value = torch.tensor(-0.1095).to(torch_device) predicted_keypoints_image0_values = outputs.keypoints[0, :3] predicted_scores_image0_values = outputs.scores[0, :9] predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]