Update output of SuperPointForKeypointDetection (#29809)

* Remove auto class

* Update ImagePointDescriptionOutput

* Update model outputs

* Rename output class

* Revert "Remove auto class"

This reverts commit ed4a8f549d79cdb0cdf7aa74205a185c41471519.

* Address comments
This commit is contained in:
NielsRogge
2024-04-11 14:59:30 +02:00
committed by GitHub
parent 386ef34e7d
commit 5569552cf8
2 changed files with 23 additions and 26 deletions

View File

@@ -85,13 +85,17 @@ class SuperPointModelTester:
border_removal_distance=self.border_removal_distance,
)
def create_and_check_model(self, config, pixel_values):
def create_and_check_keypoint_detection(self, config, pixel_values):
model = SuperPointForKeypointDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(result.keypoints.shape[0], self.batch_size)
self.parent.assertEqual(result.keypoints.shape[-1], 2)
result = model(pixel_values, output_hidden_states=True)
self.parent.assertEqual(
result.last_hidden_state.shape,
result.hidden_states[-1].shape,
(
self.batch_size,
self.encoder_hidden_sizes[-1],
@@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_retain_grad_hidden_states_attentions(self):
pass
def test_model(self):
def test_keypoint_detection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
self.model_tester.create_and_check_keypoint_detection(*config_and_inputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs()