Update output of SuperPointForKeypointDetection (#29809)
* Remove auto class * Update ImagePointDescriptionOutput * Update model outputs * Rename output class * Revert "Remove auto class" This reverts commit ed4a8f549d79cdb0cdf7aa74205a185c41471519. * Address comments
This commit is contained in:
@@ -85,13 +85,17 @@ class SuperPointModelTester:
|
||||
border_removal_distance=self.border_removal_distance,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
def create_and_check_keypoint_detection(self, config, pixel_values):
|
||||
model = SuperPointForKeypointDetection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.keypoints.shape[0], self.batch_size)
|
||||
self.parent.assertEqual(result.keypoints.shape[-1], 2)
|
||||
|
||||
result = model(pixel_values, output_hidden_states=True)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape,
|
||||
result.hidden_states[-1].shape,
|
||||
(
|
||||
self.batch_size,
|
||||
self.encoder_hidden_sizes[-1],
|
||||
@@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
def test_model(self):
|
||||
def test_keypoint_detection(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_keypoint_detection(*config_and_inputs)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user