🚨🚨🚨 [SuperPoint] Fix keypoint coordinate output and add post processing (#33200)
* feat: Added int conversion and unwrapping * test: added tests for post_process_keypoint_detection of SuperPointImageProcessor * docs: changed docs to include post_process_keypoint_detection method and switched from opencv to matplotlib * test: changed test to not depend on SuperPointModel forward * test: added missing require_torch decorator * docs: changed pyplot parameters for the keypoints to be more visible in the example * tests: changed import torch location to make test_flax and test_tf * Revert "tests: changed import torch location to make test_flax and test_tf" This reverts commit 39b32a2f69500bc7af01715fc7beae2260549afe. * tests: fixed import * chore: applied suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * tests: fixed import * tests: fixed import (bis) * tests: fixed import (ter) * feat: added choice of type for target_size and changed tests accordingly * docs: updated code snippet to reflect the addition of target size type choice in post process method * tests: fixed imports (...) * tests: fixed imports (...) * style: formatting file * docs: fixed typo from image[0] to image.size[0] * docs: added output image and fixed some tests * Update docs/source/en/model_doc/superpoint.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * fix: included SuperPointKeypointDescriptionOutput in TYPE_CHECKING if statement and changed tests results to reflect changes to SuperPoint from absolute keypoints coordinates to relative * docs: changed SuperPoint's docs to print output instead of just accessing * style: applied make style * docs: added missing output type and precision in docstring of post_process_keypoint_detection * perf: deleted loop to perform keypoint conversion in one statement * fix: moved keypoint conversion at the end of model forward * docs: changed SuperPointInterestPointDecoder to SuperPointKeypointDecoder class name and added relative (x, y) coordinates information to its method * fix: changed type hint * refactor: removed unnecessary brackets * revert: SuperPointKeypointDecoder to SuperPointInterestPointDecoder * Update docs/source/en/model_doc/superpoint.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> --------- Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
@@ -86,24 +86,32 @@ model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/sup
|
|||||||
|
|
||||||
inputs = processor(images, return_tensors="pt")
|
inputs = processor(images, return_tensors="pt")
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
image_sizes = [(image.height, image.width) for image in images]
|
||||||
|
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)
|
||||||
|
|
||||||
for i in range(len(images)):
|
for output in outputs:
|
||||||
image_mask = outputs.mask[i]
|
for keypoints, scores, descriptors in zip(output["keypoints"], output["scores"], output["descriptors"]):
|
||||||
image_indices = torch.nonzero(image_mask).squeeze()
|
print(f"Keypoints: {keypoints}")
|
||||||
image_keypoints = outputs.keypoints[i][image_indices]
|
print(f"Scores: {scores}")
|
||||||
image_scores = outputs.scores[i][image_indices]
|
print(f"Descriptors: {descriptors}")
|
||||||
image_descriptors = outputs.descriptors[i][image_indices]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
You can then print the keypoints on the image to visualize the result :
|
You can then print the keypoints on the image of your choice to visualize the result:
|
||||||
```python
|
```python
|
||||||
import cv2
|
import matplotlib.pyplot as plt
|
||||||
for keypoint, score in zip(image_keypoints, image_scores):
|
|
||||||
keypoint_x, keypoint_y = int(keypoint[0].item()), int(keypoint[1].item())
|
plt.axis("off")
|
||||||
color = tuple([score.item() * 255] * 3)
|
plt.imshow(image_1)
|
||||||
image = cv2.circle(image, (keypoint_x, keypoint_y), 2, color)
|
plt.scatter(
|
||||||
cv2.imwrite("output_image.png", image)
|
outputs[0]["keypoints"][:, 0],
|
||||||
|
outputs[0]["keypoints"][:, 1],
|
||||||
|
c=outputs[0]["scores"] * 100,
|
||||||
|
s=outputs[0]["scores"] * 50,
|
||||||
|
alpha=0.8
|
||||||
|
)
|
||||||
|
plt.savefig(f"output_image.png")
|
||||||
```
|
```
|
||||||
|

|
||||||
|
|
||||||
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
||||||
The original code can be found [here](https://github.com/magicleap/SuperPointPretrainedNetwork).
|
The original code can be found [here](https://github.com/magicleap/SuperPointPretrainedNetwork).
|
||||||
@@ -123,6 +131,7 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
|||||||
[[autodoc]] SuperPointImageProcessor
|
[[autodoc]] SuperPointImageProcessor
|
||||||
|
|
||||||
- preprocess
|
- preprocess
|
||||||
|
- post_process_keypoint_detection
|
||||||
|
|
||||||
## SuperPointForKeypointDetection
|
## SuperPointForKeypointDetection
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for SuperPoint."""
|
"""Image processor class for SuperPoint."""
|
||||||
|
|
||||||
from typing import Dict, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ... import is_vision_available
|
from ... import is_torch_available, is_vision_available
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import resize, to_channel_dimension_format
|
from ...image_transforms import resize, to_channel_dimension_format
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
@@ -32,6 +32,12 @@ from ...image_utils import (
|
|||||||
from ...utils import TensorType, logging, requires_backends
|
from ...utils import TensorType, logging, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .modeling_superpoint import SuperPointKeypointDescriptionOutput
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
import PIL
|
import PIL
|
||||||
|
|
||||||
@@ -270,3 +276,52 @@ class SuperPointImageProcessor(BaseImageProcessor):
|
|||||||
data = {"pixel_values": images}
|
data = {"pixel_values": images}
|
||||||
|
|
||||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def post_process_keypoint_detection(
|
||||||
|
self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, List[Tuple]]
|
||||||
|
) -> List[Dict[str, "torch.Tensor"]]:
|
||||||
|
"""
|
||||||
|
Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors
|
||||||
|
with coordinates absolute to the original image sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs ([`SuperPointKeypointDescriptionOutput`]):
|
||||||
|
Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors.
|
||||||
|
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`):
|
||||||
|
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||||
|
`(height, width)` of each image in the batch. This must be the original
|
||||||
|
image size (before any processing).
|
||||||
|
Returns:
|
||||||
|
`List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according
|
||||||
|
to target_sizes, scores and descriptors for an image in the batch as predicted by the model.
|
||||||
|
"""
|
||||||
|
if len(outputs.mask) != len(target_sizes):
|
||||||
|
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
|
||||||
|
|
||||||
|
if isinstance(target_sizes, List):
|
||||||
|
image_sizes = torch.tensor(target_sizes)
|
||||||
|
else:
|
||||||
|
if target_sizes.shape[1] != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
|
||||||
|
)
|
||||||
|
image_sizes = target_sizes
|
||||||
|
|
||||||
|
# Flip the image sizes to (width, height) and convert keypoints to absolute coordinates
|
||||||
|
image_sizes = torch.flip(image_sizes, [1])
|
||||||
|
masked_keypoints = outputs.keypoints * image_sizes[:, None]
|
||||||
|
|
||||||
|
# Convert masked_keypoints to int
|
||||||
|
masked_keypoints = masked_keypoints.to(torch.int32)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for image_mask, keypoints, scores, descriptors in zip(
|
||||||
|
outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors
|
||||||
|
):
|
||||||
|
indices = torch.nonzero(image_mask).squeeze(1)
|
||||||
|
keypoints = keypoints[indices]
|
||||||
|
scores = scores[indices]
|
||||||
|
descriptors = descriptors[indices]
|
||||||
|
results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|||||||
@@ -239,7 +239,10 @@ class SuperPointInterestPointDecoder(nn.Module):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation"""
|
"""
|
||||||
|
Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation.
|
||||||
|
The keypoints are in the form of relative (x, y) coordinates.
|
||||||
|
"""
|
||||||
_, height, width = scores.shape
|
_, height, width = scores.shape
|
||||||
|
|
||||||
# Threshold keypoints by score value
|
# Threshold keypoints by score value
|
||||||
@@ -447,7 +450,7 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
|||||||
|
|
||||||
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
|
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
|
||||||
|
|
||||||
batch_size = pixel_values.shape[0]
|
batch_size, _, height, width = pixel_values.shape
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
pixel_values,
|
pixel_values,
|
||||||
@@ -485,6 +488,9 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
|||||||
descriptors[i, : _descriptors.shape[0]] = _descriptors
|
descriptors[i, : _descriptors.shape[0]] = _descriptors
|
||||||
mask[i, : _scores.shape[0]] = 1
|
mask[i, : _scores.shape[0]] = 1
|
||||||
|
|
||||||
|
# Convert to relative coordinates
|
||||||
|
keypoints = keypoints / torch.tensor([width, height], device=keypoints.device)
|
||||||
|
|
||||||
hidden_states = encoder_outputs[1] if output_hidden_states else None
|
hidden_states = encoder_outputs[1] if output_hidden_states else None
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
|
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
from transformers.utils import is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_image_processing_common import (
|
from ...test_image_processing_common import (
|
||||||
ImageProcessingTestMixin,
|
ImageProcessingTestMixin,
|
||||||
@@ -24,6 +24,11 @@ from ...test_image_processing_common import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.models.superpoint.modeling_superpoint import SuperPointKeypointDescriptionOutput
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from transformers import SuperPointImageProcessor
|
from transformers import SuperPointImageProcessor
|
||||||
|
|
||||||
@@ -70,6 +75,23 @@ class SuperPointImageProcessingTester(unittest.TestCase):
|
|||||||
torchify=torchify,
|
torchify=torchify,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_keypoint_detection_output(self, pixel_values):
|
||||||
|
max_number_keypoints = 50
|
||||||
|
batch_size = len(pixel_values)
|
||||||
|
mask = torch.zeros((batch_size, max_number_keypoints))
|
||||||
|
keypoints = torch.zeros((batch_size, max_number_keypoints, 2))
|
||||||
|
scores = torch.zeros((batch_size, max_number_keypoints))
|
||||||
|
descriptors = torch.zeros((batch_size, max_number_keypoints, 16))
|
||||||
|
for i in range(batch_size):
|
||||||
|
random_number_keypoints = np.random.randint(0, max_number_keypoints)
|
||||||
|
mask[i, :random_number_keypoints] = 1
|
||||||
|
keypoints[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 2))
|
||||||
|
scores[i, :random_number_keypoints] = torch.rand((random_number_keypoints,))
|
||||||
|
descriptors[i, :random_number_keypoints] = torch.rand((random_number_keypoints, 16))
|
||||||
|
return SuperPointKeypointDescriptionOutput(
|
||||||
|
loss=None, keypoints=keypoints, scores=scores, descriptors=descriptors, mask=mask, hidden_states=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
@@ -110,3 +132,33 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
|||||||
pre_processed_images = image_processor.preprocess(image_inputs)
|
pre_processed_images = image_processor.preprocess(image_inputs)
|
||||||
for image in pre_processed_images["pixel_values"]:
|
for image in pre_processed_images["pixel_values"]:
|
||||||
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
|
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_post_processing_keypoint_detection(self):
|
||||||
|
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||||
|
image_inputs = self.image_processor_tester.prepare_image_inputs()
|
||||||
|
pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
|
||||||
|
outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)
|
||||||
|
|
||||||
|
def check_post_processed_output(post_processed_output, image_size):
|
||||||
|
for post_processed_output, image_size in zip(post_processed_output, image_size):
|
||||||
|
self.assertTrue("keypoints" in post_processed_output)
|
||||||
|
self.assertTrue("descriptors" in post_processed_output)
|
||||||
|
self.assertTrue("scores" in post_processed_output)
|
||||||
|
keypoints = post_processed_output["keypoints"]
|
||||||
|
all_below_image_size = torch.all(keypoints[:, 0] <= image_size[1]) and torch.all(
|
||||||
|
keypoints[:, 1] <= image_size[0]
|
||||||
|
)
|
||||||
|
all_above_zero = torch.all(keypoints[:, 0] >= 0) and torch.all(keypoints[:, 1] >= 0)
|
||||||
|
self.assertTrue(all_below_image_size)
|
||||||
|
self.assertTrue(all_above_zero)
|
||||||
|
|
||||||
|
tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs]
|
||||||
|
tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)
|
||||||
|
|
||||||
|
check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
|
||||||
|
|
||||||
|
tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
|
||||||
|
tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes)
|
||||||
|
|
||||||
|
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
expected_number_keypoints_image0 = 567
|
expected_number_keypoints_image0 = 568
|
||||||
expected_number_keypoints_image1 = 830
|
expected_number_keypoints_image1 = 830
|
||||||
expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1)
|
expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1)
|
||||||
expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2))
|
expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2))
|
||||||
@@ -275,11 +275,13 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape)
|
self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape)
|
||||||
self.assertEqual(outputs.scores.shape, expected_scores_shape)
|
self.assertEqual(outputs.scores.shape, expected_scores_shape)
|
||||||
self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape)
|
self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape)
|
||||||
expected_keypoints_image0_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]).to(torch_device)
|
expected_keypoints_image0_values = torch.tensor([[0.75, 0.0188], [0.7719, 0.0188], [0.7641, 0.0333]]).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
expected_scores_image0_values = torch.tensor(
|
expected_scores_image0_values = torch.tensor(
|
||||||
[0.0064, 0.0137, 0.0589, 0.0723, 0.5166, 0.0174, 0.1515, 0.2054, 0.0334]
|
[0.0064, 0.0139, 0.0591, 0.0727, 0.5170, 0.0175, 0.1526, 0.2057, 0.0335]
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
expected_descriptors_image0_value = torch.tensor(-0.1096).to(torch_device)
|
expected_descriptors_image0_value = torch.tensor(-0.1095).to(torch_device)
|
||||||
predicted_keypoints_image0_values = outputs.keypoints[0, :3]
|
predicted_keypoints_image0_values = outputs.keypoints[0, :3]
|
||||||
predicted_scores_image0_values = outputs.scores[0, :9]
|
predicted_scores_image0_values = outputs.scores[0, :9]
|
||||||
predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]
|
predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]
|
||||||
|
|||||||
Reference in New Issue
Block a user