🚨🚨🚨 [SuperPoint] Fix keypoint coordinate output and add post processing (#33200)
* feat: Added int conversion and unwrapping * test: added tests for post_process_keypoint_detection of SuperPointImageProcessor * docs: changed docs to include post_process_keypoint_detection method and switched from opencv to matplotlib * test: changed test to not depend on SuperPointModel forward * test: added missing require_torch decorator * docs: changed pyplot parameters for the keypoints to be more visible in the example * tests: changed import torch location to make test_flax and test_tf * Revert "tests: changed import torch location to make test_flax and test_tf" This reverts commit 39b32a2f69500bc7af01715fc7beae2260549afe. * tests: fixed import * chore: applied suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * tests: fixed import * tests: fixed import (bis) * tests: fixed import (ter) * feat: added choice of type for target_size and changed tests accordingly * docs: updated code snippet to reflect the addition of target size type choice in post process method * tests: fixed imports (...) * tests: fixed imports (...) * style: formatting file * docs: fixed typo from image[0] to image.size[0] * docs: added output image and fixed some tests * Update docs/source/en/model_doc/superpoint.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * fix: included SuperPointKeypointDescriptionOutput in TYPE_CHECKING if statement and changed tests results to reflect changes to SuperPoint from absolute keypoints coordinates to relative * docs: changed SuperPoint's docs to print output instead of just accessing * style: applied make style * docs: added missing output type and precision in docstring of post_process_keypoint_detection * perf: deleted loop to perform keypoint conversion in one statement * fix: moved keypoint conversion at the end of model forward * docs: changed SuperPointInterestPointDecoder to SuperPointKeypointDecoder class name and added relative (x, y) coordinates information to its method * fix: changed type hint * refactor: removed unnecessary brackets * revert: SuperPointKeypointDecoder to SuperPointInterestPointDecoder * Update docs/source/en/model_doc/superpoint.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> --------- Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
@@ -86,24 +86,32 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup
|
||||
|
||||
inputs = processor(images, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
image_sizes = [(image.height, image.width) for image in images]
|
||||
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)
|
||||
|
||||
for i in range(len(images)):
|
||||
image_mask = outputs.mask[i]
|
||||
image_indices = torch.nonzero(image_mask).squeeze()
|
||||
image_keypoints = outputs.keypoints[i][image_indices]
|
||||
image_scores = outputs.scores[i][image_indices]
|
||||
image_descriptors = outputs.descriptors[i][image_indices]
|
||||
for output in outputs:
|
||||
for keypoints, scores, descriptors in zip(output["keypoints"], output["scores"], output["descriptors"]):
|
||||
print(f"Keypoints: {keypoints}")
|
||||
print(f"Scores: {scores}")
|
||||
print(f"Descriptors: {descriptors}")
|
||||
```
|
||||
|
||||
You can then print the keypoints on the image to visualize the result :
|
||||
You can then print the keypoints on the image of your choice to visualize the result:
|
||||
```python
|
||||
import cv2
|
||||
for keypoint, score in zip(image_keypoints, image_scores):
|
||||
keypoint_x, keypoint_y = int(keypoint[0].item()), int(keypoint[1].item())
|
||||
color = tuple([score.item() * 255] * 3)
|
||||
image = cv2.circle(image, (keypoint_x, keypoint_y), 2, color)
|
||||
cv2.imwrite("output_image.png", image)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.axis("off")
|
||||
plt.imshow(image_1)
|
||||
plt.scatter(
|
||||
outputs[0]["keypoints"][:, 0],
|
||||
outputs[0]["keypoints"][:, 1],
|
||||
c=outputs[0]["scores"] * 100,
|
||||
s=outputs[0]["scores"] * 50,
|
||||
alpha=0.8
|
||||
)
|
||||
plt.savefig(f"output_image.png")
|
||||
```
|
||||

|
||||
|
||||
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
||||
The original code can be found [here](https://github.com/magicleap/SuperPointPretrainedNetwork).
|
||||
@@ -123,6 +131,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] SuperPointImageProcessor
|
||||
|
||||
- preprocess
|
||||
- post_process_keypoint_detection
|
||||
|
||||
## SuperPointForKeypointDetection
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
"""Image processor class for SuperPoint."""
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ... import is_vision_available
|
||||
from ... import is_torch_available, is_vision_available
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import resize, to_channel_dimension_format
|
||||
from ...image_utils import (
|
||||
@@ -32,6 +32,12 @@ from ...image_utils import (
|
||||
from ...utils import TensorType, logging, requires_backends
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_superpoint import SuperPointKeypointDescriptionOutput
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
@@ -270,3 +276,52 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
||||
data = {"pixel_values": images}
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
def post_process_keypoint_detection(
|
||||
self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, List[Tuple]]
|
||||
) -> List[Dict[str, "torch.Tensor"]]:
|
||||
"""
|
||||
Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors
|
||||
with coordinates absolute to the original image sizes.
|
||||
|
||||
Args:
|
||||
outputs ([`SuperPointKeypointDescriptionOutput`]):
|
||||
Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors.
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
`(height, width)` of each image in the batch. This must be the original
|
||||
image size (before any processing).
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according
|
||||
to target_sizes, scores and descriptors for an image in the batch as predicted by the model.
|
||||
"""
|
||||
if len(outputs.mask) != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
|
||||
|
||||
if isinstance(target_sizes, List):
|
||||
image_sizes = torch.tensor(target_sizes)
|
||||
else:
|
||||
if target_sizes.shape[1] != 2:
|
||||
raise ValueError(
|
||||
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
|
||||
)
|
||||
image_sizes = target_sizes
|
||||
|
||||
# Flip the image sizes to (width, height) and convert keypoints to absolute coordinates
|
||||
image_sizes = torch.flip(image_sizes, [1])
|
||||
masked_keypoints = outputs.keypoints * image_sizes[:, None]
|
||||
|
||||
# Convert masked_keypoints to int
|
||||
masked_keypoints = masked_keypoints.to(torch.int32)
|
||||
|
||||
results = []
|
||||
for image_mask, keypoints, scores, descriptors in zip(
|
||||
outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors
|
||||
):
|
||||
indices = torch.nonzero(image_mask).squeeze(1)
|
||||
keypoints = keypoints[indices]
|
||||
scores = scores[indices]
|
||||
descriptors = descriptors[indices]
|
||||
results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors})
|
||||
|
||||
return results
|
||||
|
||||
@@ -239,7 +239,10 @@ class SuperPointInterestPointDecoder(nn.Module):
|
||||
return scores
|
||||
|
||||
def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation"""
|
||||
"""
|
||||
Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation.
|
||||
The keypoints are in the form of relative (x, y) coordinates.
|
||||
"""
|
||||
_, height, width = scores.shape
|
||||
|
||||
# Threshold keypoints by score value
|
||||
@@ -447,7 +450,7 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
||||
|
||||
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
|
||||
|
||||
batch_size = pixel_values.shape[0]
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
pixel_values,
|
||||
@@ -485,6 +488,9 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
||||
descriptors[i, : _descriptors.shape[0]] = _descriptors
|
||||
mask[i, : _scores.shape[0]] = 1
|
||||
|
||||
# Convert to relative coordinates
|
||||
keypoints = keypoints / torch.tensor([width, height], device=keypoints.device)
|
||||
|
||||
hidden_states = encoder_outputs[1] if output_hidden_states else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
|
||||
|
||||
@@ -16,7 +16,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import (
|
||||
ImageProcessingTestMixin,
|
||||
@@ -24,6 +24,11 @@ from ...test_image_processing_common import (
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.superpoint.modeling_superpoint import SuperPointKeypointDescriptionOutput
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import SuperPointImageProcessor
|
||||
|
||||
@@ -70,6 +75,23 @@ class SuperPointImageProcessingTester(unittest.TestCase):
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
def prepare_keypoint_detection_output(self, pixel_values):
|
||||
max_number_keypoints = 50
|
||||
batch_size = len(pixel_values)
|
||||
mask = torch.zeros((batch_size, max_number_keypoints))
|
||||
keypoints = torch.zeros((batch_size, max_number_keypoints, 2))
|
||||
scores = torch.zeros((batch_size, max_number_keypoints))
|
||||
descriptors = torch.zeros((batch_size, max_number_keypoints, 16))
|
||||
for i in range(batch_size):
|
||||
random_number_keypoints = np.random.randint(0, max_number_keypoints)
|
||||
mask[i, :random_number_keypoints] = 1
|
||||
keypoints[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 2))
|
||||
scores[i, :random_number_keypoints] = torch.rand((random_number_keypoints,))
|
||||
descriptors[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 16))
|
||||
return SuperPointKeypointDescriptionOutput(
|
||||
loss=None, keypoints=keypoints, scores=scores, descriptors=descriptors, mask=mask, hidden_states=None
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
@@ -110,3 +132,33 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
pre_processed_images = image_processor.preprocess(image_inputs)
|
||||
for image in pre_processed_images["pixel_values"]:
|
||||
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
|
||||
|
||||
@require_torch
|
||||
def test_post_processing_keypoint_detection(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs()
|
||||
pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
|
||||
outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)
|
||||
|
||||
def check_post_processed_output(post_processed_output, image_size):
|
||||
for post_processed_output, image_size in zip(post_processed_output, image_size):
|
||||
self.assertTrue("keypoints" in post_processed_output)
|
||||
self.assertTrue("descriptors" in post_processed_output)
|
||||
self.assertTrue("scores" in post_processed_output)
|
||||
keypoints = post_processed_output["keypoints"]
|
||||
all_below_image_size = torch.all(keypoints[:, 0] <= image_size[1]) and torch.all(
|
||||
keypoints[:, 1] <= image_size[0]
|
||||
)
|
||||
all_above_zero = torch.all(keypoints[:, 0] >= 0) and torch.all(keypoints[:, 1] >= 0)
|
||||
self.assertTrue(all_below_image_size)
|
||||
self.assertTrue(all_above_zero)
|
||||
|
||||
tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs]
|
||||
tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)
|
||||
|
||||
check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
|
||||
|
||||
tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
|
||||
tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes)
|
||||
|
||||
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
|
||||
|
||||
@@ -260,7 +260,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
|
||||
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
expected_number_keypoints_image0 = 567
|
||||
expected_number_keypoints_image0 = 568
|
||||
expected_number_keypoints_image1 = 830
|
||||
expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1)
|
||||
expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2))
|
||||
@@ -275,11 +275,13 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape)
|
||||
self.assertEqual(outputs.scores.shape, expected_scores_shape)
|
||||
self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape)
|
||||
expected_keypoints_image0_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]).to(torch_device)
|
||||
expected_keypoints_image0_values = torch.tensor([[0.75, 0.0188], [0.7719, 0.0188], [0.7641, 0.0333]]).to(
|
||||
torch_device
|
||||
)
|
||||
expected_scores_image0_values = torch.tensor(
|
||||
[0.0064, 0.0137, 0.0589, 0.0723, 0.5166, 0.0174, 0.1515, 0.2054, 0.0334]
|
||||
[0.0064, 0.0139, 0.0591, 0.0727, 0.5170, 0.0175, 0.1526, 0.2057, 0.0335]
|
||||
).to(torch_device)
|
||||
expected_descriptors_image0_value = torch.tensor(-0.1096).to(torch_device)
|
||||
expected_descriptors_image0_value = torch.tensor(-0.1095).to(torch_device)
|
||||
predicted_keypoints_image0_values = outputs.keypoints[0, :3]
|
||||
predicted_scores_image0_values = outputs.scores[0, :9]
|
||||
predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]
|
||||
|
||||
Reference in New Issue
Block a user