From 5569552cf8779c8951326b2fa9b7a1d64b1005c9 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Thu, 11 Apr 2024 14:59:30 +0200 Subject: [PATCH] Update output of SuperPointForKeypointDetection (#29809) * Remove auto class * Update ImagePointDescriptionOutput * Update model outputs * Rename output class * Revert "Remove auto class" This reverts commit ed4a8f549d79cdb0cdf7aa74205a185c41471519. * Address comments --- .../models/superpoint/modeling_superpoint.py | 29 +++++++------------ .../superpoint/test_modeling_superpoint.py | 20 ++++++++----- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py index a4350e6d79..3e3fdbbf10 100644 --- a/src/transformers/models/superpoint/modeling_superpoint.py +++ b/src/transformers/models/superpoint/modeling_superpoint.py @@ -79,7 +79,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor: @dataclass -class ImagePointDescriptionOutput(ModelOutput): +class SuperPointKeypointDescriptionOutput(ModelOutput): """ Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images, @@ -88,8 +88,8 @@ class ImagePointDescriptionOutput(ModelOutput): and which are padding. Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): Relative (x, y) coordinates of predicted keypoints in a given image. scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`): @@ -105,7 +105,7 @@ class ImagePointDescriptionOutput(ModelOutput): (also called feature maps) of the model at the output of each stage. """ - last_hidden_state: torch.FloatTensor = None + loss: Optional[torch.FloatTensor] = None keypoints: Optional[torch.IntTensor] = None scores: Optional[torch.FloatTensor] = None descriptors: Optional[torch.FloatTensor] = None @@ -414,11 +414,11 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): @add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING) def forward( self, - pixel_values: torch.FloatTensor = None, + pixel_values: torch.FloatTensor, labels: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple, ImagePointDescriptionOutput]: + ) -> Union[Tuple, SuperPointKeypointDescriptionOutput]: """ Examples: @@ -437,20 +437,15 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): >>> inputs = processor(image, return_tensors="pt") >>> outputs = model(**inputs) ```""" - + loss = None if labels is not None: - raise ValueError( - f"SuperPoint is not trainable, no labels should be provided.Therefore, labels should be None but were {type(labels)}" - ) + raise ValueError("SuperPoint does not support training for now.") output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - pixel_values = self.extract_one_channel_pixel_values(pixel_values) batch_size = pixel_values.shape[0] @@ -493,12 +488,10 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): hidden_states = encoder_outputs[1] if output_hidden_states else None if not return_dict: - return tuple( - v for v in [last_hidden_state, keypoints, scores, descriptors, mask, hidden_states] if v is not None - ) + return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None) - return ImagePointDescriptionOutput( - last_hidden_state=last_hidden_state, + return SuperPointKeypointDescriptionOutput( + loss=loss, keypoints=keypoints, scores=scores, descriptors=descriptors, diff --git a/tests/models/superpoint/test_modeling_superpoint.py b/tests/models/superpoint/test_modeling_superpoint.py index cb204d3f89..080eda385b 100644 --- a/tests/models/superpoint/test_modeling_superpoint.py +++ b/tests/models/superpoint/test_modeling_superpoint.py @@ -85,13 +85,17 @@ class SuperPointModelTester: border_removal_distance=self.border_removal_distance, ) - def create_and_check_model(self, config, pixel_values): + def create_and_check_keypoint_detection(self, config, pixel_values): model = SuperPointForKeypointDetection(config=config) model.to(torch_device) model.eval() result = model(pixel_values) + self.parent.assertEqual(result.keypoints.shape[0], self.batch_size) + self.parent.assertEqual(result.keypoints.shape[-1], 2) + + result = model(pixel_values, output_hidden_states=True) self.parent.assertEqual( - result.last_hidden_state.shape, + result.hidden_states[-1].shape, ( self.batch_size, self.encoder_hidden_sizes[-1], @@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase): def test_feed_forward_chunking(self): pass - @unittest.skip(reason="SuperPointForKeypointDetection is not trainable") + @unittest.skip(reason="SuperPointForKeypointDetection does not support training") def test_training(self): pass - @unittest.skip(reason="SuperPointForKeypointDetection is not trainable") + @unittest.skip(reason="SuperPointForKeypointDetection does not support training") def test_training_gradient_checkpointing(self): pass - @unittest.skip(reason="SuperPointForKeypointDetection is not trainable") + @unittest.skip(reason="SuperPointForKeypointDetection does not support training") def test_training_gradient_checkpointing_use_reentrant(self): pass - @unittest.skip(reason="SuperPointForKeypointDetection is not trainable") + @unittest.skip(reason="SuperPointForKeypointDetection does not support training") def test_training_gradient_checkpointing_use_reentrant_false(self): pass @@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase): def test_retain_grad_hidden_states_attentions(self): pass - def test_model(self): + def test_keypoint_detection(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) + self.model_tester.create_and_check_keypoint_detection(*config_and_inputs) def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs()