🚨🚨🚨 [SuperPoint] Fix keypoint coordinate output and add post processing (#33200)

* feat: Added int conversion and unwrapping

* test: added tests for post_process_keypoint_detection of SuperPointImageProcessor

* docs: changed docs to include post_process_keypoint_detection method and switched from opencv to matplotlib

* test: changed test to not depend on SuperPointModel forward

* test: added missing require_torch decorator

* docs: changed pyplot parameters for the keypoints to be more visible in the example

* tests: changed import torch location to make test_flax and test_tf

* Revert "tests: changed import torch location to make test_flax and test_tf"

This reverts commit 39b32a2f69500bc7af01715fc7beae2260549afe.

* tests: fixed import

* chore: applied suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* tests: fixed import

* tests: fixed import (bis)

* tests: fixed import (ter)

* feat: added choice of type for target_size and changed tests accordingly

* docs: updated code snippet to reflect the addition of target size type choice in post process method

* tests: fixed imports (...)

* tests: fixed imports (...)

* style: formatting file

* docs: fixed typo from image[0] to image.size[0]

* docs: added output image and fixed some tests

* Update docs/source/en/model_doc/superpoint.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* fix: included SuperPointKeypointDescriptionOutput in TYPE_CHECKING if statement and changed tests results to reflect changes to SuperPoint from absolute keypoints coordinates to relative

* docs: changed SuperPoint's docs to print output instead of just accessing

* style: applied make style

* docs: added missing output type and precision in docstring of post_process_keypoint_detection

* perf: deleted loop to perform keypoint conversion in one statement

* fix: moved keypoint conversion at the end of model forward

* docs: changed SuperPointInterestPointDecoder to SuperPointKeypointDecoder class name and added relative (x, y) coordinates information to its method

* fix: changed type hint

* refactor: removed unnecessary brackets

* revert: SuperPointKeypointDecoder to SuperPointInterestPointDecoder

* Update docs/source/en/model_doc/superpoint.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

---------

Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
StevenBucaille
2024-10-29 10:36:03 +01:00
committed by GitHub
parent 655bec2da7
commit a1835195d1
5 changed files with 146 additions and 22 deletions

View File

@@ -86,24 +86,32 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup
inputs = processor(images, return_tensors="pt")
outputs = model(**inputs)
image_sizes = [(image.height, image.width) for image in images]
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)
for i in range(len(images)):
image_mask = outputs.mask[i]
image_indices = torch.nonzero(image_mask).squeeze()
image_keypoints = outputs.keypoints[i][image_indices]
image_scores = outputs.scores[i][image_indices]
image_descriptors = outputs.descriptors[i][image_indices]
for output in outputs:
for keypoints, scores, descriptors in zip(output["keypoints"], output["scores"], output["descriptors"]):
print(f"Keypoints: {keypoints}")
print(f"Scores: {scores}")
print(f"Descriptors: {descriptors}")
```
You can then print the keypoints on the image to visualize the result :
You can then plot the keypoints on the image of your choice to visualize the result:
```python
import cv2
for keypoint, score in zip(image_keypoints, image_scores):
keypoint_x, keypoint_y = int(keypoint[0].item()), int(keypoint[1].item())
color = tuple([score.item() * 255] * 3)
image = cv2.circle(image, (keypoint_x, keypoint_y), 2, color)
cv2.imwrite("output_image.png", image)
import matplotlib.pyplot as plt
plt.axis("off")
plt.imshow(image_1)
plt.scatter(
outputs[0]["keypoints"][:, 0],
outputs[0]["keypoints"][:, 1],
c=outputs[0]["scores"] * 100,
s=outputs[0]["scores"] * 50,
alpha=0.8
)
plt.savefig(f"output_image.png")
```
![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/ZtFmphEhx8tcbEQqOolyE.png)
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
The original code can be found [here](https://github.com/magicleap/SuperPointPretrainedNetwork).
@@ -123,6 +131,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] SuperPointImageProcessor
- preprocess
- post_process_keypoint_detection
## SuperPointForKeypointDetection

View File

@@ -13,11 +13,11 @@
# limitations under the License.
"""Image processor class for SuperPoint."""
from typing import Dict, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
from ... import is_vision_available
from ... import is_torch_available, is_vision_available
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
@@ -32,6 +32,12 @@ from ...image_utils import (
from ...utils import TensorType, logging, requires_backends
if is_torch_available():
import torch
if TYPE_CHECKING:
from .modeling_superpoint import SuperPointKeypointDescriptionOutput
if is_vision_available():
import PIL
@@ -270,3 +276,52 @@ class SuperPointImageProcessor(BaseImageProcessor):
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
def post_process_keypoint_detection(
    self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, List[Tuple]]
) -> List[Dict[str, "torch.Tensor"]]:
    """
    Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors,
    with the keypoint coordinates expressed in absolute pixels of the original images.

    Args:
        outputs ([`SuperPointKeypointDescriptionOutput`]):
            Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors.
        target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`):
            Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
            `(height, width)` of each image in the batch. This must be the original
            image size (before any processing).

    Returns:
        `List[Dict]`: A list of dictionaries, one per image in the batch, each holding the `"keypoints"` in
        absolute (x, y) coordinates according to `target_sizes` (truncated to `torch.int32`), together with
        the matching `"scores"` and `"descriptors"`. Only the entries selected by `outputs.mask` are kept.

    Raises:
        ValueError: If the number of target sizes differs from the batch dimension of the mask, or if
            `target_sizes` cannot be interpreted as a `(batch_size, 2)` collection of `(height, width)` pairs.
    """
    if len(outputs.mask) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")

    # Build the sizes on the same device as the keypoints: a list input would otherwise always land on
    # the CPU and break the multiplication below when the model ran on an accelerator.
    image_sizes = torch.as_tensor(target_sizes, device=outputs.keypoints.device)

    # Validate lists and tensors alike; a malformed list such as [(h,)] would otherwise silently
    # broadcast a single value over both coordinates.
    if image_sizes.ndim != 2 or image_sizes.shape[1] != 2:
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

    # Flip the image sizes to (width, height) and convert keypoints to absolute coordinates
    image_sizes = torch.flip(image_sizes, [1])
    masked_keypoints = outputs.keypoints * image_sizes[:, None]

    # Convert masked_keypoints to int (truncation towards zero)
    masked_keypoints = masked_keypoints.to(torch.int32)

    results = []
    for image_mask, keypoints, scores, descriptors in zip(
        outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors
    ):
        # Positions of the non-zero mask entries, i.e. the slots that hold real (non-padding) keypoints
        indices = torch.nonzero(image_mask).squeeze(1)
        results.append(
            {
                "keypoints": keypoints[indices],
                "scores": scores[indices],
                "descriptors": descriptors[indices],
            }
        )

    return results

View File

@@ -239,7 +239,10 @@ class SuperPointInterestPointDecoder(nn.Module):
return scores
def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation"""
"""
Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation.
The keypoints are in the form of relative (x, y) coordinates.
"""
_, height, width = scores.shape
# Threshold keypoints by score value
@@ -447,7 +450,7 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
batch_size = pixel_values.shape[0]
batch_size, _, height, width = pixel_values.shape
encoder_outputs = self.encoder(
pixel_values,
@@ -485,6 +488,9 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
descriptors[i, : _descriptors.shape[0]] = _descriptors
mask[i, : _scores.shape[0]] = 1
# Convert to relative coordinates
keypoints = keypoints / torch.tensor([width, height], device=keypoints.device)
hidden_states = encoder_outputs[1] if output_hidden_states else None
if not return_dict:
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)

View File

@@ -16,7 +16,7 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import (
ImageProcessingTestMixin,
@@ -24,6 +24,11 @@ from ...test_image_processing_common import (
)
if is_torch_available():
import torch
from transformers.models.superpoint.modeling_superpoint import SuperPointKeypointDescriptionOutput
if is_vision_available():
from transformers import SuperPointImageProcessor
@@ -70,6 +75,23 @@ class SuperPointImageProcessingTester(unittest.TestCase):
torchify=torchify,
)
def prepare_keypoint_detection_output(self, pixel_values):
    """
    Build a synthetic `SuperPointKeypointDescriptionOutput` for a batch of images.

    Every tensor is zero-padded up to a fixed number of keypoint slots per image. For each image a random
    count of leading slots is marked valid in the mask and filled with uniform random keypoints, scores
    and descriptors.

    Args:
        pixel_values: Batch of preprocessed images; only its length (the batch size) is used.

    Returns:
        A `SuperPointKeypointDescriptionOutput` with `loss` and `hidden_states` set to `None`.
    """
    keypoint_slots = 50
    num_images = len(pixel_values)

    # Zero-initialised, padded containers; zeros in the mask denote padding slots.
    mask = torch.zeros((num_images, keypoint_slots))
    keypoints = torch.zeros((num_images, keypoint_slots, 2))
    scores = torch.zeros((num_images, keypoint_slots))
    descriptors = torch.zeros((num_images, keypoint_slots, 16))

    for image_idx in range(num_images):
        # Number of valid keypoints for this image (may be zero, always below the slot count).
        valid_count = np.random.randint(0, keypoint_slots)
        valid = slice(None, valid_count)
        mask[image_idx, valid] = 1
        keypoints[image_idx, valid] = torch.rand((valid_count, 2))
        scores[image_idx, valid] = torch.rand((valid_count,))
        descriptors[image_idx, valid] = torch.rand((valid_count, 16))

    return SuperPointKeypointDescriptionOutput(
        loss=None,
        keypoints=keypoints,
        scores=scores,
        descriptors=descriptors,
        mask=mask,
        hidden_states=None,
    )
@require_torch
@require_vision
@@ -110,3 +132,33 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
pre_processed_images = image_processor.preprocess(image_inputs)
for image in pre_processed_images["pixel_values"]:
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
@require_torch
def test_post_processing_keypoint_detection(self):
    """
    Check that `post_process_keypoint_detection` returns, for every image, a dict with keypoints, scores and
    descriptors, and that the absolute keypoint coordinates stay inside the image bounds. Target sizes are
    exercised both as a list of tuples and as a tensor.
    """
    image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
    image_inputs = self.image_processor_tester.prepare_image_inputs()
    pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
    outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)

    def check_post_processed_output(post_processed_outputs, image_sizes):
        # Distinct loop names so the per-image values do not shadow the batch-level arguments.
        for image_output, (height, width) in zip(post_processed_outputs, image_sizes):
            self.assertTrue("keypoints" in image_output)
            self.assertTrue("descriptors" in image_output)
            self.assertTrue("scores" in image_output)
            keypoints = image_output["keypoints"]
            # Keypoints are (x, y): x is bounded by the width, y by the height.
            all_below_image_size = torch.all(keypoints[:, 0] <= width) and torch.all(keypoints[:, 1] <= height)
            all_above_zero = torch.all(keypoints[:, 0] >= 0) and torch.all(keypoints[:, 1] >= 0)
            self.assertTrue(all_below_image_size)
            self.assertTrue(all_above_zero)

    # The post-processing expects (height, width). `image.size` is taken to be (width, height), in line
    # with the flip applied in the tensor variant below, so the tuple variant swaps the two entries too.
    tuple_image_sizes = [(image.size[1], image.size[0]) for image in image_inputs]
    tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)

    check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)

    tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
    tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes)

    check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)

View File

@@ -260,7 +260,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
expected_number_keypoints_image0 = 567
expected_number_keypoints_image0 = 568
expected_number_keypoints_image1 = 830
expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1)
expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2))
@@ -275,11 +275,13 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape)
self.assertEqual(outputs.scores.shape, expected_scores_shape)
self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape)
expected_keypoints_image0_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]).to(torch_device)
expected_keypoints_image0_values = torch.tensor([[0.75, 0.0188], [0.7719, 0.0188], [0.7641, 0.0333]]).to(
torch_device
)
expected_scores_image0_values = torch.tensor(
[0.0064, 0.0137, 0.0589, 0.0723, 0.5166, 0.0174, 0.1515, 0.2054, 0.0334]
[0.0064, 0.0139, 0.0591, 0.0727, 0.5170, 0.0175, 0.1526, 0.2057, 0.0335]
).to(torch_device)
expected_descriptors_image0_value = torch.tensor(-0.1096).to(torch_device)
expected_descriptors_image0_value = torch.tensor(-0.1095).to(torch_device)
predicted_keypoints_image0_values = outputs.keypoints[0, :3]
predicted_scores_image0_values = outputs.scores[0, :9]
predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]