Update output of SuperPointForKeypointDetection (#29809)

* Remove auto class

* Update ImagePointDescriptionOutput

* Update model outputs

* Rename output class

* Revert "Remove auto class"

This reverts commit ed4a8f549d79cdb0cdf7aa74205a185c41471519.

* Address comments
This commit is contained in:
NielsRogge
2024-04-11 14:59:30 +02:00
committed by GitHub
parent 386ef34e7d
commit 5569552cf8
2 changed files with 23 additions and 26 deletions

View File

@@ -79,7 +79,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
@dataclass @dataclass
class ImagePointDescriptionOutput(ModelOutput): class SuperPointKeypointDescriptionOutput(ModelOutput):
""" """
Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images, keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
@@ -88,8 +88,8 @@ class ImagePointDescriptionOutput(ModelOutput):
and which are padding. and which are padding.
Args: Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Sequence of hidden-states at the output of the last layer of the decoder of the model. Loss computed during training.
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
Relative (x, y) coordinates of predicted keypoints in a given image. Relative (x, y) coordinates of predicted keypoints in a given image.
scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`): scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
@@ -105,7 +105,7 @@ class ImagePointDescriptionOutput(ModelOutput):
(also called feature maps) of the model at the output of each stage. (also called feature maps) of the model at the output of each stage.
""" """
last_hidden_state: torch.FloatTensor = None loss: Optional[torch.FloatTensor] = None
keypoints: Optional[torch.IntTensor] = None keypoints: Optional[torch.IntTensor] = None
scores: Optional[torch.FloatTensor] = None scores: Optional[torch.FloatTensor] = None
descriptors: Optional[torch.FloatTensor] = None descriptors: Optional[torch.FloatTensor] = None
@@ -414,11 +414,11 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
@add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, ImagePointDescriptionOutput]: ) -> Union[Tuple, SuperPointKeypointDescriptionOutput]:
""" """
Examples: Examples:
@@ -437,20 +437,15 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
>>> inputs = processor(image, return_tensors="pt") >>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
```""" ```"""
loss = None
if labels is not None: if labels is not None:
raise ValueError( raise ValueError("SuperPoint does not support training for now.")
f"SuperPoint is not trainable, no labels should be provided.Therefore, labels should be None but were {type(labels)}"
)
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
pixel_values = self.extract_one_channel_pixel_values(pixel_values) pixel_values = self.extract_one_channel_pixel_values(pixel_values)
batch_size = pixel_values.shape[0] batch_size = pixel_values.shape[0]
@@ -493,12 +488,10 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
hidden_states = encoder_outputs[1] if output_hidden_states else None hidden_states = encoder_outputs[1] if output_hidden_states else None
if not return_dict: if not return_dict:
return tuple( return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
v for v in [last_hidden_state, keypoints, scores, descriptors, mask, hidden_states] if v is not None
)
return ImagePointDescriptionOutput( return SuperPointKeypointDescriptionOutput(
last_hidden_state=last_hidden_state, loss=loss,
keypoints=keypoints, keypoints=keypoints,
scores=scores, scores=scores,
descriptors=descriptors, descriptors=descriptors,

View File

@@ -85,13 +85,17 @@ class SuperPointModelTester:
border_removal_distance=self.border_removal_distance, border_removal_distance=self.border_removal_distance,
) )
def create_and_check_model(self, config, pixel_values): def create_and_check_keypoint_detection(self, config, pixel_values):
model = SuperPointForKeypointDetection(config=config) model = SuperPointForKeypointDetection(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.keypoints.shape[0], self.batch_size)
self.parent.assertEqual(result.keypoints.shape[-1], 2)
result = model(pixel_values, output_hidden_states=True)
self.parent.assertEqual( self.parent.assertEqual(
result.last_hidden_state.shape, result.hidden_states[-1].shape,
( (
self.batch_size, self.batch_size,
self.encoder_hidden_sizes[-1], self.encoder_hidden_sizes[-1],
@@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training(self): def test_training(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing(self): def test_training_gradient_checkpointing(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
@@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
pass pass
def test_model(self): def test_keypoint_detection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_keypoint_detection(*config_and_inputs)
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs() config, _ = self.model_tester.prepare_config_and_inputs()