🚨🚨🚨 [SuperPoint] Fix keypoint coordinate output and add post processing (#33200)

* feat: Added int conversion and unwrapping

* test: added tests for post_process_keypoint_detection of SuperPointImageProcessor

* docs: changed docs to include post_process_keypoint_detection method and switched from opencv to matplotlib

* test: changed test to not depend on SuperPointModel forward

* test: added missing require_torch decorator

* docs: changed pyplot parameters for the keypoints to be more visible in the example

* tests: changed import torch location to make test_flax and test_tf

* Revert "tests: changed import torch location to make test_flax and test_tf"

This reverts commit 39b32a2f69500bc7af01715fc7beae2260549afe.

* tests: fixed import

* chore: applied suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* tests: fixed import

* tests: fixed import (bis)

* tests: fixed import (ter)

* feat: added choice of type for target_size and changed tests accordingly

* docs: updated code snippet to reflect the addition of target size type choice in post process method

* tests: fixed imports (...)

* tests: fixed imports (...)

* style: formatting file

* docs: fixed typo from image[0] to image.size[0]

* docs: added output image and fixed some tests

* Update docs/source/en/model_doc/superpoint.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* fix: included SuperPointKeypointDescriptionOutput in TYPE_CHECKING if statement and changed tests results to reflect changes to SuperPoint from absolute keypoints coordinates to relative

* docs: changed SuperPoint's docs to print output instead of just accessing

* style: applied make style

* docs: added missing output type and precision in docstring of post_process_keypoint_detection

* perf: deleted loop to perform keypoint conversion in one statement

* fix: moved keypoint conversion at the end of model forward

* docs: changed SuperPointInterestPointDecoder to SuperPointKeypointDecoder class name and added relative (x, y) coordinates information to its method

* fix: changed type hint

* refactor: removed unnecessary brackets

* revert: SuperPointKeypointDecoder to SuperPointInterestPointDecoder

* Update docs/source/en/model_doc/superpoint.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

---------

Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
StevenBucaille
2024-10-29 10:36:03 +01:00
committed by GitHub
parent 655bec2da7
commit a1835195d1
5 changed files with 146 additions and 22 deletions

View File

@@ -260,7 +260,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
expected_number_keypoints_image0 = 567
expected_number_keypoints_image0 = 568
expected_number_keypoints_image1 = 830
expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1)
expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2))
@@ -275,11 +275,13 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape)
self.assertEqual(outputs.scores.shape, expected_scores_shape)
self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape)
expected_keypoints_image0_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]).to(torch_device)
expected_keypoints_image0_values = torch.tensor([[0.75, 0.0188], [0.7719, 0.0188], [0.7641, 0.0333]]).to(
torch_device
)
expected_scores_image0_values = torch.tensor(
[0.0064, 0.0137, 0.0589, 0.0723, 0.5166, 0.0174, 0.1515, 0.2054, 0.0334]
[0.0064, 0.0139, 0.0591, 0.0727, 0.5170, 0.0175, 0.1526, 0.2057, 0.0335]
).to(torch_device)
expected_descriptors_image0_value = torch.tensor(-0.1096).to(torch_device)
expected_descriptors_image0_value = torch.tensor(-0.1095).to(torch_device)
predicted_keypoints_image0_values = outputs.keypoints[0, :3]
predicted_scores_image0_values = outputs.scores[0, :9]
predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]